mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b494ed81d4 | ||
|
|
ff3474f05c |
@@ -775,7 +775,6 @@ decoder_trainer = DecoderTrainer(
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
||||
loss.backward()
|
||||
|
||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||
|
||||
@@ -839,7 +838,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
||||
)
|
||||
|
||||
loss = diffusion_prior_trainer(text, images)
|
||||
loss.backward()
|
||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||
|
||||
# after much of the above three lines in a loop
|
||||
@@ -1017,6 +1015,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -1163,6 +1163,7 @@ class CrossAttention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
norm_context = False
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -1172,7 +1173,7 @@ class CrossAttention(nn.Module):
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.norm_context = LayerNorm(context_dim)
|
||||
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -1378,6 +1379,9 @@ class Unet(nn.Module):
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
|
||||
# text encoding conditioning (optional)
|
||||
|
||||
self.text_to_cond = None
|
||||
@@ -1593,6 +1597,11 @@ class Unet(nn.Module):
|
||||
|
||||
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||
|
||||
# normalize conditioning tokens
|
||||
|
||||
c = self.norm_cond(c)
|
||||
mid_c = self.norm_mid_cond(mid_c)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
|
||||
@@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*args, **kwargs)
|
||||
return self.scaler.scale(loss / divisor)
|
||||
scaled_loss = self.scaler.scale(loss / divisor)
|
||||
scaled_loss.backward()
|
||||
return loss.item()
|
||||
|
||||
# decoder trainer
|
||||
|
||||
@@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module):
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||
return self.scale(loss / divisor, unet_number = unet_number)
|
||||
scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
|
||||
scaled_loss.backward()
|
||||
return loss.item()
|
||||
|
||||
Reference in New Issue
Block a user