mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 14:54:21 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
708638d3d9 | ||
|
|
b494ed81d4 |
13
README.md
13
README.md
@@ -732,8 +732,8 @@ clip = CLIP(
|
|||||||
|
|
||||||
# mock data
|
# mock data
|
||||||
|
|
||||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||||
images = torch.randn(4, 3, 256, 256).cuda()
|
images = torch.randn(32, 3, 256, 256).cuda()
|
||||||
|
|
||||||
# decoder (with unet)
|
# decoder (with unet)
|
||||||
|
|
||||||
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for unet_number in (1, 2):
|
for unet_number in (1, 2):
|
||||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
loss = decoder_trainer(
|
||||||
loss.backward()
|
images,
|
||||||
|
text = text,
|
||||||
|
unet_number = unet_number, # which unet to train on
|
||||||
|
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
|
||||||
|
)
|
||||||
|
|
||||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||||
|
|
||||||
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
loss = diffusion_prior_trainer(text, images)
|
loss = diffusion_prior_trainer(text, images)
|
||||||
loss.backward()
|
|
||||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||||
|
|
||||||
# after much of the above three lines in a loop
|
# after much of the above three lines in a loop
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
|
from math import ceil
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
|
|||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
return val if isinstance(val, tuple) else ((val,) * length)
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
@@ -40,6 +45,42 @@ def groupby_prefix_and_trim(prefix, d):
|
|||||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||||
return kwargs_without_prefix, kwargs
|
return kwargs_without_prefix, kwargs
|
||||||
|
|
||||||
|
# gradient accumulation functions
|
||||||
|
|
||||||
|
def split_iterable(it, split_size):
|
||||||
|
accum = []
|
||||||
|
for ind in range(ceil(len(it) / split_size)):
|
||||||
|
start_index = ind * split_size
|
||||||
|
accum.append(it[start_index: (start_index + split_size)])
|
||||||
|
return accum
|
||||||
|
|
||||||
|
def split(t, split_size = None):
|
||||||
|
if not exists(split_size):
|
||||||
|
return t
|
||||||
|
|
||||||
|
if isinstance(t, torch.Tensor):
|
||||||
|
return t.split(split_size, dim = 0)
|
||||||
|
|
||||||
|
if isinstance(t, Iterable):
|
||||||
|
return split_iterable(t, split_size)
|
||||||
|
|
||||||
|
return TypeError
|
||||||
|
|
||||||
|
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
|
||||||
|
batch_size = len(x)
|
||||||
|
chunk_size = ceil(batch_size / default(split_size, batch_size))
|
||||||
|
|
||||||
|
dict_len = len(kwargs)
|
||||||
|
dict_keys = kwargs.keys()
|
||||||
|
all_args = (x, *args, *kwargs.values())
|
||||||
|
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
||||||
|
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||||
|
|
||||||
|
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||||
|
chunked_args, chunked_kwargs_values = chunked_all_args[:-dict_len], chunked_all_args[-dict_len:]
|
||||||
|
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
||||||
|
yield chunk_size, (chunked_args, chunked_kwargs)
|
||||||
|
|
||||||
# print helpers
|
# print helpers
|
||||||
|
|
||||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||||
@@ -209,12 +250,23 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
divisor = 1,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
with autocast(enabled = self.amp):
|
total_samples = 0
|
||||||
loss = self.diffusion_prior(*args, **kwargs)
|
total_loss = 0.
|
||||||
return self.scaler.scale(loss / divisor)
|
|
||||||
|
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||||
|
with autocast(enabled = self.amp):
|
||||||
|
loss = self.diffusion_prior(*args, **kwargs)
|
||||||
|
|
||||||
|
total_loss += loss.item() * chunk_size
|
||||||
|
total_samples += chunk_size
|
||||||
|
|
||||||
|
scaled_loss = self.scaler.scale(loss)
|
||||||
|
scaled_loss.backward()
|
||||||
|
|
||||||
|
return total_loss / total_samples
|
||||||
|
|
||||||
# decoder trainer
|
# decoder trainer
|
||||||
|
|
||||||
@@ -325,9 +377,20 @@ class DecoderTrainer(nn.Module):
|
|||||||
x,
|
x,
|
||||||
*,
|
*,
|
||||||
unet_number,
|
unet_number,
|
||||||
divisor = 1,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
with autocast(enabled = self.amp):
|
total_samples = 0
|
||||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
total_loss = 0.
|
||||||
return self.scale(loss / divisor, unet_number = unet_number)
|
|
||||||
|
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
|
||||||
|
with autocast(enabled = self.amp):
|
||||||
|
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||||
|
|
||||||
|
total_loss += loss.item() * chunk_size
|
||||||
|
total_samples += chunk_size
|
||||||
|
|
||||||
|
scaled_loss = self.scale(loss, unet_number = unet_number)
|
||||||
|
scaled_loss.backward()
|
||||||
|
|
||||||
|
return total_loss / total_samples
|
||||||
|
|||||||
Reference in New Issue
Block a user