make sure gradient accumulation feature works even if all arguments passed in are keyword arguments

Phil Wang
2022-05-15 11:16:16 -07:00
parent 74f222596a
commit 68e7d2f241
3 changed files with 20 additions and 13 deletions
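The root cause: `split_args_and_kwargs` previously took the batch tensor as a required positional parameter `x` and read the batch size from it, so a call made entirely with keyword arguments failed before any splitting happened. A minimal sketch of the failure mode, based on the old signature visible in the diff below (the call itself is illustrative):

```python
def split_args_and_kwargs_old(x, *args, split_size = None, **kwargs):
    batch_size = len(x)  # batch size could only come from the positional `x`
    ...

# passing everything by keyword leaves `x` unbound and fails immediately:
#   split_args_and_kwargs_old(text = text, image = images, split_size = 4)
#   TypeError: split_args_and_kwargs_old() missing 1 required positional argument: 'x'
```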

README.md

@@ -814,8 +814,8 @@ clip = CLIP(
 # mock data

-text = torch.randint(0, 49408, (4, 256)).cuda()
-images = torch.randn(4, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()

 # prior networks (with transformer)
@@ -842,7 +842,7 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
     ema_update_every = 10,
 )

-loss = diffusion_prior_trainer(text, images)
+loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
 diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior

 # after much of the above three lines in a loop
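For context on the README change: with the mock batch of 32 and `max_batch_size = 4`, the trainer runs eight forward passes of four samples each, weighting every chunk's loss by the fraction of the batch it covers. A quick sketch of that arithmetic (names illustrative):

```python
batch_size = 32
max_batch_size = 4

num_chunks = batch_size // max_batch_size       # 8 forward passes per step
chunk_size_frac = max_batch_size / batch_size   # each chunk contributes 4/32 = 0.125
# total_loss = sum of (chunk_loss * chunk_size_frac) over all chunks,
# matching the mean loss of a single full-batch forward pass
```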

dalle2_pytorch/trainer.py

@@ -66,15 +66,24 @@ def split(t, split_size = None):
     return TypeError

-def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
-    batch_size = len(x)
+def find_first(cond, arr):
+    for el in arr:
+        if cond(el):
+            return el
+    return None
+
+def split_args_and_kwargs(*args, split_size = None, **kwargs):
+    all_args = (*args, *kwargs.values())
+    len_all_args = len(all_args)
+    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
+    assert exists(first_tensor)
+
+    batch_size = len(first_tensor)
     split_size = default(split_size, batch_size)
     chunk_size = ceil(batch_size / split_size)

     dict_len = len(kwargs)
     dict_keys = kwargs.keys()
-    all_args = (x, *args, *kwargs.values())
-    len_all_args = len(all_args)
     split_kwargs_index = len_all_args - dict_len

     split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
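A self-contained sketch of the new batch-size discovery (mirroring the hunk above, not the library's exact code): the splitter now scans positional and keyword values in order and takes the length of the first tensor it finds, so the batch no longer has to arrive positionally.

```python
import torch

def find_first(cond, arr):
    for el in arr:
        if cond(el):
            return el
    return None

def batch_size_of(*args, **kwargs):
    # same discovery rule as the patched split_args_and_kwargs
    all_args = (*args, *kwargs.values())
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert first_tensor is not None, 'at least one argument must be a tensor'
    return len(first_tensor)

print(batch_size_of(torch.randn(32, 3, 256, 256)))               # 32, positional
print(batch_size_of(text = torch.randint(0, 49408, (32, 256))))  # 32, keyword-only
```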
@@ -258,14 +267,13 @@ class DiffusionPriorTrainer(nn.Module):
     def forward(
         self,
-        x,
         *args,
         max_batch_size = None,
         **kwargs
     ):
         total_loss = 0.

-        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with autocast(enabled = self.amp):
                 loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
                 loss = loss * chunk_size_frac
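Why scaling by `chunk_size_frac` is correct: for a mean-reduced loss, each chunk's mean weighted by its share of the batch sums to the full-batch mean, so accumulation reproduces the un-chunked loss. A quick check under that assumption (standalone, not the library's code):

```python
import torch
import torch.nn.functional as F

preds, targets = torch.randn(32, 10), torch.randn(32, 10)
full_loss = F.mse_loss(preds, targets)

total_loss = 0.
for p, t in zip(preds.split(4), targets.split(4)):
    chunk_size_frac = len(p) / len(preds)  # 4 / 32 per chunk
    total_loss = total_loss + F.mse_loss(p, t) * chunk_size_frac

assert torch.allclose(full_loss, total_loss)  # accumulation matches the full batch
```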
@@ -385,15 +393,14 @@ class DecoderTrainer(nn.Module):
     def forward(
         self,
-        x,
-        *,
+        *args,
         unet_number,
         max_batch_size = None,
         **kwargs
     ):
         total_loss = 0.

-        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with autocast(enabled = self.amp):
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                 loss = loss * chunk_size_frac
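With both trainers now forwarding `*args` straight into the splitter, a fully keyword-argument call is valid, which is the point of the commit. An illustrative call shape; the keyword names `image` and `text` are assumptions about the decoder's signature, not confirmed by this diff:

```python
loss = decoder_trainer(
    image = images,        # assumed keyword name for the image batch
    text = text,           # assumed optional conditioning keyword
    unet_number = 1,
    max_batch_size = 4     # accumulate gradients in chunks of 4
)
decoder_trainer.update(unet_number = 1)  # step the optimizer, following the README pattern
```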

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.30',
+  version = '0.2.31',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',