Compare commits

...

7 Commits

Author | SHA1 | Message | Date
Phil Wang | 3afdcdfe86 | need to keep track of training steps separately for each unet in decoder trainer | 2022-07-05 15:17:59 -07:00
Phil Wang | b9a908ff75 | bring in two tricks from the cogview paper for reducing the chances of overflow, for attention and layernorm | 2022-07-05 14:27:04 -07:00
Phil Wang | e1fe3089df | do bias-less layernorm manually | 2022-07-05 13:09:58 -07:00
Phil Wang | 6d477d7654 | link to dalle2 laion | 2022-07-05 11:43:07 -07:00
Phil Wang | 531fe4b62f | status | 2022-07-05 10:46:55 -07:00
Phil Wang | ec5a77fc55 | 0.15.4 | 2022-07-02 08:56:34 -07:00
Aidan Dempster | fac63c61bc | Fixed variable naming issue (#183) | 2022-07-02 08:56:03 -07:00
5 changed files with 43 additions and 21 deletions

README.md

@@ -20,18 +20,20 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
 - Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
-<img src="./samples/oxford.png" width="600px" />
+<img src="./samples/oxford.png" width="450px" />
 *ongoing at 21k steps*
 - <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
 - <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues
 ## Pre-Trained Models
 - LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
 - Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
 - Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
-- DALL-E 2 🚧
+- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 Laion repository</a>
 ## Appreciation

dalle2_pytorch/dalle2_pytorch.py

@@ -490,14 +490,16 @@ class NoiseScheduler(nn.Module):
 # diffusion prior

 class LayerNorm(nn.Module):
-    def __init__(self, dim):
+    def __init__(self, dim, eps = 1e-5):
         super().__init__()
-        self.gamma = nn.Parameter(torch.ones(dim))
-        self.register_buffer("beta", torch.zeros(dim))
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(dim))

     def forward(self, x):
-        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+        x = x / x.amax(dim = -1, keepdim = True).detach()
+        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = -1, keepdim = True)
+        return (x - mean) * (var + self.eps).rsqrt() * self.g

 class ChanLayerNorm(nn.Module):
     def __init__(self, dim, eps = 1e-5):
@@ -506,10 +508,10 @@ class ChanLayerNorm(nn.Module):
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

     def forward(self, x):
+        x = x / x.amax(dim = 1, keepdim = True).detach()
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g
+        return (x - mean) * (var + self.eps).rsqrt() * self.g

 class Residual(nn.Module):
     def __init__(self, fn):
@@ -629,10 +631,13 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        rotary_emb = None
+        rotary_emb = None,
+        pb_relax_alpha = 32 ** 2
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
         self.heads = heads
         inner_dim = dim_head * heads
@@ -696,6 +701,9 @@ class Attention(nn.Module):
         # attention

+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
+
         attn = sim.softmax(dim = -1, dtype = torch.float32)
         attn = self.dropout(attn)
@@ -1210,10 +1218,12 @@ class CrossAttention(nn.Module):
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        norm_context = False
+        norm_context = False,
+        pb_relax_alpha = 32 ** 2
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
         self.heads = heads
         inner_dim = dim_head * heads
@@ -1259,6 +1269,9 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)

+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
+
         attn = sim.softmax(dim = -1, dtype = torch.float32)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
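
The hunks above implement the two CogView tricks referenced in commit b9a908ff75: the LayerNorm divides activations by their detached maximum before normalizing (and is done bias-less by hand, per commit e1fe3089df), and attention computes its logits at an extra 1/alpha scale, subtracts the row maximum, then multiplies by alpha only just before the softmax, so no intermediate value overflows in half precision. A minimal standalone sketch of both tricks, using illustrative names and shapes rather than the repository's actual API:

```python
import torch
from torch import einsum

def stable_layer_norm(x, g, eps = 1e-5):
    # divide by the detached per-row max first, so mean / variance are taken over O(1) values
    x = x / x.amax(dim = -1, keepdim = True).detach()
    var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
    mean = torch.mean(x, dim = -1, keepdim = True)
    return (x - mean) * (var + eps).rsqrt() * g

def pb_relax_attention(q, k, v, alpha = 32 ** 2):
    # fold an extra 1 / alpha into the usual 1 / sqrt(d) scaling so the q.k logits stay small
    scale = q.shape[-1] ** -0.5 * (alpha ** -1)
    sim = einsum('b h i d, b h j d -> b h i j', q * scale, k)

    # subtract the detached row max, then undo the 1 / alpha factor; the result equals the
    # ordinary max-subtracted logits, but large magnitudes are never materialized
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    sim = sim * alpha

    attn = sim.softmax(dim = -1, dtype = torch.float32)
    return einsum('b h i j, b h j d -> b h i d', attn.to(q.dtype), v)

q = k = v = torch.randn(1, 8, 16, 64)
out = pb_relax_attention(q, k, v)                    # (1, 8, 16, 64)
normed = stable_layer_norm(out, g = torch.ones(64))
```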

dalle2_pytorch/train_configs.py

@@ -346,17 +346,17 @@ class TrainDecoderConfig(BaseModel):
         img_emb_url = data_config.img_embeddings_url
         text_emb_url = data_config.text_embeddings_url

-        if using_text_encodings:
+        if using_text_embeddings:
             # Then we need some way to get the embeddings
             assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'

         if using_clip:
-            if using_text_encodings:
+            if using_text_embeddings:
                 assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
             else:
                 assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'

         if text_emb_url:
-            assert using_text_encodings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
+            assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."

         return values
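
The hunk above only renames a misleading flag (the condition concerns text embeddings, not text encodings, per #183), but the rules its asserts enforce are easy to state on their own. A hedged restatement, with hypothetical function and argument names rather than the repository's config classes:

```python
def check_embedding_sources(using_clip, using_text_embeddings, img_emb_url, text_emb_url):
    # text conditioning needs some source of text embeddings: a clip model or a precomputed URL
    if using_text_embeddings:
        assert using_clip or text_emb_url is not None, \
            'If text conditioning, either clip or text_embeddings_url must be provided'

    # a loaded clip model makes precomputed embedding URLs redundant
    if using_clip:
        if using_text_embeddings:
            assert text_emb_url is None or img_emb_url is None, \
                'Loaded clip, but also provided text_embeddings_url and img_embeddings_url'
        else:
            assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url'

    # do not pay the dataloader cost for embeddings that are never conditioned on
    if text_emb_url is not None:
        assert using_text_embeddings, 'Text embeddings are being loaded but not conditioned on'
```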

dalle2_pytorch/trainer.py

@@ -6,6 +6,7 @@ from functools import partial, wraps
 from collections.abc import Iterable

 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch.cuda.amp import autocast, GradScaler
@@ -474,7 +475,7 @@ class DecoderTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm

-        self.register_buffer('step', torch.tensor([0.]))
+        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

         decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
@@ -491,7 +492,7 @@ class DecoderTrainer(nn.Module):
         save_obj = dict(
             model = self.accelerator.unwrap_model(self.decoder).state_dict(),
             version = __version__,
-            step = self.step.item(),
+            steps = self.steps.cpu(),
             **kwargs
         )
@@ -510,7 +511,7 @@ class DecoderTrainer(nn.Module):
             self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')

         self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
-        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
+        self.steps.copy_(loaded_obj['steps'])

         if only_model:
             return loaded_obj
@@ -539,6 +540,12 @@ class DecoderTrainer(nn.Module):
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

+    def increment_step(self, unet_number):
+        assert 1 <= unet_number <= self.num_unets
+
+        unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
+        self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
+
     def update(self, unet_number = None):
         if self.num_unets == 1:
             unet_number = default(unet_number, 1)
@@ -557,7 +564,7 @@ class DecoderTrainer(nn.Module):
         ema_unet = self.ema_unets[index]
         ema_unet.update()

-        self.step += 1
+        self.increment_step(unet_number)

     @torch.no_grad()
     @cast_torch_tensor
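
The trainer changes above replace the single scalar `step` buffer with a `steps` buffer holding one counter per unet; `increment_step` bumps exactly one entry via a one-hot vector, and save/load now round-trip the whole tensor. A minimal sketch of that bookkeeping in isolation, with an illustrative class name (not the repository's `DecoderTrainer`):

```python
import torch
import torch.nn.functional as F
from torch import nn

class UnetStepTracker(nn.Module):
    def __init__(self, num_unets):
        super().__init__()
        # one integer counter per unet, registered as a buffer so it lands in state_dict
        self.register_buffer('steps', torch.tensor([0] * num_unets))

    def increment_step(self, unet_number):
        # unet_number is 1-indexed, matching the convention in the diff above
        assert 1 <= unet_number <= len(self.steps)
        unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
        self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))

tracker = UnetStepTracker(num_unets = 2)
tracker.increment_step(1)
tracker.increment_step(1)
tracker.increment_step(2)
print(tracker.steps.tolist())           # [2, 1]
print(tracker.state_dict()['steps'])    # tensor([2, 1]) - saved and restored like any buffer
```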

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '0.15.3'
+__version__ = '0.16.3'