Compare commits

..

3 Commits

4 changed files with 91 additions and 15 deletions

View File

@@ -732,8 +732,8 @@ clip = CLIP(
# mock data # mock data
text = torch.randint(0, 49408, (4, 256)).cuda() text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda() images = torch.randn(32, 3, 256, 256).cuda()
# decoder (with unet) # decoder (with unet)
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
) )
for unet_number in (1, 2): for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward loss = decoder_trainer(
loss.backward() images,
text = text,
unet_number = unet_number, # which unet to train on
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
)
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
) )
loss = diffusion_prior_trainer(text, images) loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop # after much of the above three lines in a loop
@@ -1017,6 +1020,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training - [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt - [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
## Citations ## Citations

View File

@@ -1163,6 +1163,7 @@ class CrossAttention(nn.Module):
dim_head = 64, dim_head = 64,
heads = 8, heads = 8,
dropout = 0., dropout = 0.,
norm_context = False
): ):
super().__init__() super().__init__()
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
@@ -1172,7 +1173,7 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim) context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim) self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim) self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -1378,6 +1379,9 @@ class Unet(nn.Module):
Rearrange('b (n d) -> b n d', n = num_image_tokens) Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity() ) if image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
# text encoding conditioning (optional) # text encoding conditioning (optional)
self.text_to_cond = None self.text_to_cond = None
@@ -1593,6 +1597,11 @@ class Unet(nn.Module):
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2) mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# normalize conditioning tokens
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# go through the layers of the unet, down and up # go through the layers of the unet, down and up
hiddens = [] hiddens = []

View File

@@ -1,6 +1,8 @@
import time import time
import copy import copy
from math import ceil
from functools import partial from functools import partial
from collections.abc import Iterable
import torch import torch
from torch import nn from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
def exists(val): def exists(val):
return val is not None return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1): def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length) return val if isinstance(val, tuple) else ((val,) * length)
@@ -40,6 +45,42 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs return kwargs_without_prefix, kwargs
# gradient accumulation functions
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
batch_size = len(x)
chunk_size = ceil(batch_size / default(split_size, batch_size))
dict_len = len(kwargs)
dict_keys = kwargs.keys()
all_args = (x, *args, *kwargs.values())
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:-dict_len], chunked_all_args[-dict_len:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
yield chunk_size, (chunked_args, chunked_kwargs)
# print helpers # print helpers
def print_ribbon(s, symbol = '=', repeat = 40): def print_ribbon(s, symbol = '=', repeat = 40):
@@ -209,12 +250,23 @@ class DiffusionPriorTrainer(nn.Module):
def forward( def forward(
self, self,
*args, *args,
divisor = 1, max_batch_size = None,
**kwargs **kwargs
): ):
with autocast(enabled = self.amp): total_samples = 0
loss = self.diffusion_prior(*args, **kwargs) total_loss = 0.
return self.scaler.scale(loss / divisor)
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
total_loss += loss.item() * chunk_size
total_samples += chunk_size
scaled_loss = self.scaler.scale(loss)
scaled_loss.backward()
return total_loss / total_samples
# decoder trainer # decoder trainer
@@ -325,9 +377,20 @@ class DecoderTrainer(nn.Module):
x, x,
*, *,
unet_number, unet_number,
divisor = 1, max_batch_size = None,
**kwargs **kwargs
): ):
with autocast(enabled = self.amp): total_samples = 0
loss = self.decoder(x, unet_number = unet_number, **kwargs) total_loss = 0.
return self.scale(loss / divisor, unet_number = unet_number)
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
total_loss += loss.item() * chunk_size
total_samples += chunk_size
scaled_loss = self.scale(loss, unet_number = unet_number)
scaled_loss.backward()
return total_loss / total_samples

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.22', version = '0.2.25',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',