Compare commits

..

1 Commits

2 changed files with 14 additions and 9 deletions

View File

@@ -68,16 +68,20 @@ def split(t, split_size = None):
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
batch_size = len(x)
chunk_size = ceil(batch_size / default(split_size, batch_size))
split_size = default(split_size, batch_size)
chunk_size = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
all_args = (x, *args, *kwargs.values())
len_all_args = len(all_args)
split_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:-dict_len], chunked_all_args[-dict_len:]
chunked_args, chunked_kwargs_values = chunked_all_args[:split_index], chunked_all_args[split_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
yield chunk_size, (chunked_args, chunked_kwargs)
@@ -249,22 +253,23 @@ class DiffusionPriorTrainer(nn.Module):
def forward(
self,
x,
*args,
max_batch_size = None,
**kwargs
):
batch_size = x.shape[0]
total_samples = 0
total_loss = 0.
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
total_loss += loss.item() * chunk_size
total_samples += chunk_size
scaled_loss = self.scaler.scale(loss)
scaled_loss.backward()
self.scaler.scale(loss * (chunk_size / batch_size)).backward()
return total_loss / total_samples
@@ -380,6 +385,7 @@ class DecoderTrainer(nn.Module):
max_batch_size = None,
**kwargs
):
batch_size = x.shape[0]
total_samples = 0
total_loss = 0.
@@ -390,7 +396,6 @@ class DecoderTrainer(nn.Module):
total_loss += loss.item() * chunk_size
total_samples += chunk_size
scaled_loss = self.scale(loss, unet_number = unet_number)
scaled_loss.backward()
self.scale(loss * (chunk_size / batch_size), unet_number = unet_number).backward()
return total_loss / total_samples

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.25',
version = '0.2.26',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',