mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
make sure gradient accumulation feature works even if all arguments passed in are keyword arguments
This commit is contained in:
@@ -814,8 +814,8 @@ clip = CLIP(
|
|||||||
|
|
||||||
# mock data
|
# mock data
|
||||||
|
|
||||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||||
images = torch.randn(4, 3, 256, 256).cuda()
|
images = torch.randn(32, 3, 256, 256).cuda()
|
||||||
|
|
||||||
# prior networks (with transformer)
|
# prior networks (with transformer)
|
||||||
|
|
||||||
@@ -842,7 +842,7 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
|||||||
ema_update_every = 10,
|
ema_update_every = 10,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = diffusion_prior_trainer(text, images)
|
loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
|
||||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||||
|
|
||||||
# after much of the above three lines in a loop
|
# after much of the above three lines in a loop
|
||||||
|
|||||||
@@ -66,15 +66,24 @@ def split(t, split_size = None):
|
|||||||
|
|
||||||
return TypeError
|
return TypeError
|
||||||
|
|
||||||
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
|
def find_first(cond, arr):
|
||||||
batch_size = len(x)
|
for el in arr:
|
||||||
|
if cond(el):
|
||||||
|
return el
|
||||||
|
return None
|
||||||
|
|
||||||
|
def split_args_and_kwargs(*args, split_size = None, **kwargs):
|
||||||
|
all_args = (*args, *kwargs.values())
|
||||||
|
len_all_args = len(all_args)
|
||||||
|
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
|
||||||
|
assert exists(first_tensor)
|
||||||
|
|
||||||
|
batch_size = len(first_tensor)
|
||||||
split_size = default(split_size, batch_size)
|
split_size = default(split_size, batch_size)
|
||||||
chunk_size = ceil(batch_size / split_size)
|
chunk_size = ceil(batch_size / split_size)
|
||||||
|
|
||||||
dict_len = len(kwargs)
|
dict_len = len(kwargs)
|
||||||
dict_keys = kwargs.keys()
|
dict_keys = kwargs.keys()
|
||||||
all_args = (x, *args, *kwargs.values())
|
|
||||||
len_all_args = len(all_args)
|
|
||||||
split_kwargs_index = len_all_args - dict_len
|
split_kwargs_index = len_all_args - dict_len
|
||||||
|
|
||||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
||||||
@@ -258,14 +267,13 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x,
|
|
||||||
*args,
|
*args,
|
||||||
max_batch_size = None,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
total_loss = 0.
|
total_loss = 0.
|
||||||
|
|
||||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||||
with autocast(enabled = self.amp):
|
with autocast(enabled = self.amp):
|
||||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||||
loss = loss * chunk_size_frac
|
loss = loss * chunk_size_frac
|
||||||
@@ -385,15 +393,14 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x,
|
*args,
|
||||||
*,
|
|
||||||
unet_number,
|
unet_number,
|
||||||
max_batch_size = None,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
total_loss = 0.
|
total_loss = 0.
|
||||||
|
|
||||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||||
with autocast(enabled = self.amp):
|
with autocast(enabled = self.amp):
|
||||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||||
loss = loss * chunk_size_frac
|
loss = loss * chunk_size_frac
|
||||||
|
|||||||
Reference in New Issue
Block a user