diff --git a/README.md b/README.md
index 06c5755..61239e1 100644
--- a/README.md
+++ b/README.md
@@ -814,8 +814,8 @@ clip = CLIP(
 
 # mock data
 
-text = torch.randint(0, 49408, (4, 256)).cuda()
-images = torch.randn(4, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()
 
 # prior networks (with transformer)
 
@@ -842,7 +842,7 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
     ema_update_every = 10,
 )
 
-loss = diffusion_prior_trainer(text, images)
+loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
 diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
 
 # after much of the above three lines in a loop
diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index c8bb946..80b35aa 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -66,15 +66,24 @@ def split(t, split_size = None):
 
     return TypeError
 
-def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
-    batch_size = len(x)
+def find_first(cond, arr):
+    for el in arr:
+        if cond(el):
+            return el
+    return None
+
+def split_args_and_kwargs(*args, split_size = None, **kwargs):
+    all_args = (*args, *kwargs.values())
+    len_all_args = len(all_args)
+    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
+    assert exists(first_tensor)
+
+    batch_size = len(first_tensor)
     split_size = default(split_size, batch_size)
     chunk_size = ceil(batch_size / split_size)
 
     dict_len = len(kwargs)
     dict_keys = kwargs.keys()
-    all_args = (x, *args, *kwargs.values())
-    len_all_args = len(all_args)
     split_kwargs_index = len_all_args - dict_len
 
     split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
@@ -258,14 +267,13 @@ class DiffusionPriorTrainer(nn.Module):
 
     def forward(
         self,
-        x,
         *args,
         max_batch_size = None,
         **kwargs
     ):
         total_loss = 0.
 
-        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with autocast(enabled = self.amp):
                 loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
                 loss = loss * chunk_size_frac
@@ -385,15 +393,14 @@ class DecoderTrainer(nn.Module):
 
     def forward(
         self,
-        x,
-        *,
+        *args,
         unet_number,
         max_batch_size = None,
         **kwargs
     ):
         total_loss = 0.
 
-        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with autocast(enabled = self.amp):
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                 loss = loss * chunk_size_frac
diff --git a/setup.py b/setup.py
index 99c82f1..214487d 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.30',
+  version = '0.2.31',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',