mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 21:44:29 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9549bd43b7 | ||
|
|
aee92dba4a | ||
|
|
b0cd5f24b6 |
@@ -68,18 +68,23 @@ def split(t, split_size = None):
|
|||||||
|
|
||||||
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
|
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
|
||||||
batch_size = len(x)
|
batch_size = len(x)
|
||||||
chunk_size = ceil(batch_size / default(split_size, batch_size))
|
split_size = default(split_size, batch_size)
|
||||||
|
chunk_size = ceil(batch_size / split_size)
|
||||||
|
|
||||||
dict_len = len(kwargs)
|
dict_len = len(kwargs)
|
||||||
dict_keys = kwargs.keys()
|
dict_keys = kwargs.keys()
|
||||||
all_args = (x, *args, *kwargs.values())
|
all_args = (x, *args, *kwargs.values())
|
||||||
|
len_all_args = len(all_args)
|
||||||
|
split_kwargs_index = len_all_args - dict_len
|
||||||
|
|
||||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
||||||
chunk_sizes = tuple(map(len, split_all_args[0]))
|
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||||
|
|
||||||
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||||
chunked_args, chunked_kwargs_values = chunked_all_args[:-dict_len], chunked_all_args[-dict_len:]
|
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
|
||||||
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
||||||
yield chunk_size, (chunked_args, chunked_kwargs)
|
chunk_size_frac = chunk_size / batch_size
|
||||||
|
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||||
|
|
||||||
# print helpers
|
# print helpers
|
||||||
|
|
||||||
@@ -249,24 +254,22 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
x,
|
||||||
*args,
|
*args,
|
||||||
max_batch_size = None,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
total_samples = 0
|
|
||||||
total_loss = 0.
|
total_loss = 0.
|
||||||
|
|
||||||
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
|
||||||
with autocast(enabled = self.amp):
|
with autocast(enabled = self.amp):
|
||||||
loss = self.diffusion_prior(*args, **kwargs)
|
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||||
|
|
||||||
total_loss += loss.item() * chunk_size
|
loss = loss * chunk_size_frac
|
||||||
total_samples += chunk_size
|
total_loss += loss.item()
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
|
|
||||||
scaled_loss = self.scaler.scale(loss)
|
return total_loss
|
||||||
scaled_loss.backward()
|
|
||||||
|
|
||||||
return total_loss / total_samples
|
|
||||||
|
|
||||||
# decoder trainer
|
# decoder trainer
|
||||||
|
|
||||||
@@ -380,17 +383,14 @@ class DecoderTrainer(nn.Module):
|
|||||||
max_batch_size = None,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
total_samples = 0
|
|
||||||
total_loss = 0.
|
total_loss = 0.
|
||||||
|
|
||||||
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
|
||||||
with autocast(enabled = self.amp):
|
with autocast(enabled = self.amp):
|
||||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||||
|
|
||||||
total_loss += loss.item() * chunk_size
|
loss = loss * chunk_size_frac
|
||||||
total_samples += chunk_size
|
total_loss += loss.item()
|
||||||
|
self.scale(loss, unet_number = unet_number).backward()
|
||||||
|
|
||||||
scaled_loss = self.scale(loss, unet_number = unet_number)
|
return total_loss
|
||||||
scaled_loss.backward()
|
|
||||||
|
|
||||||
return total_loss / total_samples
|
|
||||||
|
|||||||
Reference in New Issue
Block a user