diff --git a/README.md b/README.md
index aae849e..d3ec4cb 100644
--- a/README.md
+++ b/README.md
@@ -732,8 +732,8 @@ clip = CLIP(
 
 # mock data
 
-text = torch.randint(0, 49408, (4, 256)).cuda()
-images = torch.randn(4, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()
 
 # decoder (with unet)
 
@@ -774,7 +774,12 @@ decoder_trainer = DecoderTrainer(
 )
 
 for unet_number in (1, 2):
-    loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
+    loss = decoder_trainer(
+        images,
+        text = text,
+        unet_number = unet_number, # which unet to train on
+        max_batch_size = 4         # gradient accumulation - sets the maximum batch size per forward / backward pass; here the batch of 32 is split into 32 / 4 == 8 accumulation steps
+    )
 
     decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index 244682c..b587170 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -1,6 +1,8 @@
 import time
 import copy
+from math import ceil
 from functools import partial
+from collections.abc import Iterable
 
 import torch
 from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
 
@@ -40,6 +45,46 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs
 
+# gradient accumulation functions
+
+def split_iterable(it, split_size):
+    accum = []
+    for ind in range(ceil(len(it) / split_size)):
+        start_index = ind * split_size
+        accum.append(it[start_index: (start_index + split_size)])
+    return accum
+
+def split(t, split_size = None):
+    if not exists(split_size):
+        return t
+
+    if isinstance(t, torch.Tensor):
+        return t.split(split_size, dim = 0)
+
+    if isinstance(t, Iterable):
+        return split_iterable(t, split_size)
+
+    raise TypeError
+
+def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
+    batch_size = len(x)
+    split_size = default(split_size, batch_size)
+    chunk_size = ceil(batch_size / split_size)
+
+    dict_len = len(kwargs)
+    dict_keys = kwargs.keys()
+    all_args = (x, *args, *kwargs.values())
+    len_all_args = len(all_args)
+    split_index = len_all_args - dict_len
+
+    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
+    chunk_sizes = tuple(map(len, split_all_args[0]))
+
+    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
+        chunked_args, chunked_kwargs_values = chunked_all_args[:split_index], chunked_all_args[split_index:]
+        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
+        yield chunk_size, (chunked_args, chunked_kwargs)
+
 # print helpers
 
 def print_ribbon(s, symbol = '=', repeat = 40):
@@ -208,15 +253,25 @@ class DiffusionPriorTrainer(nn.Module):
     def forward(
         self,
+        x,
         *args,
-        divisor = 1,
+        max_batch_size = None,
         **kwargs
     ):
-        with autocast(enabled = self.amp):
-            loss = self.diffusion_prior(*args, **kwargs)
-        scaled_loss = self.scaler.scale(loss / divisor)
-        scaled_loss.backward()
-        return loss.item()
+        batch_size = x.shape[0]
+        total_samples = 0
+        total_loss = 0.
+
+        for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
+
+            total_loss += loss.item() * chunk_size
+            total_samples += chunk_size
+
+            self.scaler.scale(loss * (chunk_size / batch_size)).backward()
+
+        return total_loss / total_samples
 
 # decoder trainer
 
@@ -327,11 +382,20 @@ class DecoderTrainer(nn.Module):
         x,
         *,
         unet_number,
-        divisor = 1,
+        max_batch_size = None,
         **kwargs
     ):
-        with autocast(enabled = self.amp):
-            loss = self.decoder(x, unet_number = unet_number, **kwargs)
-        scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
-        scaled_loss.backward()
-        return loss.item()
+        batch_size = x.shape[0]
+        total_samples = 0
+        total_loss = 0.
+
+        for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+
+            total_loss += loss.item() * chunk_size
+            total_samples += chunk_size
+
+            self.scale(loss * (chunk_size / batch_size), unet_number = unet_number).backward()
+
+        return total_loss / total_samples
diff --git a/setup.py b/setup.py
index 88c8d1f..b74b87b 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.24',
+  version = '0.2.26',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
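For reference, a minimal sketch of how the new `split_args_and_kwargs` generator behaves, assuming the patched `dalle2_pytorch/train.py` above is on the import path; the tensor shapes mirror the README example:

```python
import torch
from dalle2_pytorch.train import split_args_and_kwargs

# same mock shapes as the README example: a batch of 32 samples
images = torch.randn(32, 3, 256, 256)
text = torch.randint(0, 49408, (32, 256))

# split_size plays the role of max_batch_size in the trainers:
# positional tensors and keyword tensors are split along dim 0 in lockstep
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(images, split_size = 4, text = text):
    (chunked_images,) = chunked_args
    print(chunk_size, chunked_images.shape, chunked_kwargs['text'].shape)

# prints 8 lines of: 4 torch.Size([4, 3, 256, 256]) torch.Size([4, 256])
# the trainers run one forward / backward pass per chunk and weight each loss by chunk_size / batch_size before scaling
```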