take care of gradient accumulation automatically for researchers by passing a max_batch_size to the decoder or diffusion prior trainer forward

Phil Wang
2022-05-14 17:04:09 -07:00
parent b494ed81d4
commit b0cd5f24b6
3 changed files with 85 additions and 16 deletions
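
The same max_batch_size keyword also lands on DiffusionPriorTrainer.forward (second file below), which now takes the batch tensor explicitly as x. Below is a minimal sketch of the intended call, assuming a diffusion_prior_trainer already constructed as in the README examples; the variable names are illustrative and not part of this commit.

    import torch

    # batch of 32; with max_batch_size = 4 the trainer runs 8 forward/backward
    # passes and accumulates gradients before .update() is called
    text = torch.randint(0, 49408, (32, 256)).cuda()
    images = torch.randn(32, 3, 256, 256).cuda()

    loss = diffusion_prior_trainer(   # assumed: a DiffusionPriorTrainer wrapping a DiffusionPrior
        text,
        images,
        max_batch_size = 4            # new in this commit - chunks the batch for gradient accumulation
    )

    diffusion_prior_trainer.update()  # step the optimizer once gradients have accumulated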


@@ -732,8 +732,8 @@ clip = CLIP(
 # mock data

-text = torch.randint(0, 49408, (4, 256)).cuda()
-images = torch.randn(4, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()

 # decoder (with unet)
@@ -774,7 +774,12 @@ decoder_trainer = DecoderTrainer(
 )

 for unet_number in (1, 2):
-    loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
+    loss = decoder_trainer(
+        images,
+        text = text,
+        unet_number = unet_number,  # which unet to train on
+        max_batch_size = 4          # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
+    )
     decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average


@@ -1,6 +1,8 @@
 import time
 import copy
+from math import ceil
 from functools import partial
+from collections.abc import Iterable

 import torch
 from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
 def exists(val):
     return val is not None

+def default(val, d):
+    return val if exists(val) else d
+
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
@@ -40,6 +45,46 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs

+# gradient accumulation functions
+
+def split_iterable(it, split_size):
+    accum = []
+    for ind in range(ceil(len(it) / split_size)):
+        start_index = ind * split_size
+        accum.append(it[start_index: (start_index + split_size)])
+    return accum
+
+def split(t, split_size = None):
+    if not exists(split_size):
+        return t
+
+    if isinstance(t, torch.Tensor):
+        return t.split(split_size, dim = 0)
+
+    if isinstance(t, Iterable):
+        return split_iterable(t, split_size)
+
+    raise TypeError  # unsupported type for splitting
+
+def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
+    batch_size = len(x)
+    split_size = default(split_size, batch_size)
+    chunk_size = ceil(batch_size / split_size)  # number of chunks
+
+    dict_len = len(kwargs)
+    dict_keys = kwargs.keys()
+
+    all_args = (x, *args, *kwargs.values())
+    len_all_args = len(all_args)
+    split_index = len_all_args - dict_len
+
+    # split tensors / iterables along the batch dimension, broadcast everything else to each chunk
+    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
+    chunk_sizes = tuple(map(len, split_all_args[0]))
+
+    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
+        chunked_args, chunked_kwargs_values = chunked_all_args[:split_index], chunked_all_args[split_index:]
+        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
+        yield chunk_size, (chunked_args, chunked_kwargs)
+
 # print helpers

 def print_ribbon(s, symbol = '=', repeat = 40):
@@ -208,15 +253,25 @@ class DiffusionPriorTrainer(nn.Module):
     def forward(
         self,
+        x,
         *args,
-        divisor = 1,
+        max_batch_size = None,
         **kwargs
     ):
-        with autocast(enabled = self.amp):
-            loss = self.diffusion_prior(*args, **kwargs)
-            scaled_loss = self.scaler.scale(loss / divisor)
-            scaled_loss.backward()
-        return loss.item()
+        batch_size = x.shape[0]
+        total_samples = 0
+        total_loss = 0.
+
+        for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
+
+            total_loss += loss.item() * chunk_size
+            total_samples += chunk_size
+
+            self.scaler.scale(loss * (chunk_size / batch_size)).backward()
+
+        return total_loss / total_samples

 # decoder trainer
@@ -327,11 +382,20 @@ class DecoderTrainer(nn.Module):
         x,
         *,
         unet_number,
-        divisor = 1,
+        max_batch_size = None,
         **kwargs
     ):
-        with autocast(enabled = self.amp):
-            loss = self.decoder(x, unet_number = unet_number, **kwargs)
-            scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
-            scaled_loss.backward()
-        return loss.item()
+        batch_size = x.shape[0]
+        total_samples = 0
+        total_loss = 0.
+
+        for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+
+            total_loss += loss.item() * chunk_size
+            total_samples += chunk_size
+
+            self.scale(loss * (chunk_size / batch_size), unet_number = unet_number).backward()
+
+        return total_loss / total_samples
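
To make the chunking behavior concrete, here is a small illustration of how split_args_and_kwargs from the hunk above is expected to behave. It assumes the three helpers (split_iterable, split, split_args_and_kwargs) defined above are in scope; the toy tensors are purely illustrative.

    import torch

    # assumes split_iterable / split / split_args_and_kwargs from the hunk above are defined in scope

    images = torch.randn(10, 3, 8, 8)            # toy batch of 10 samples
    text = torch.randint(0, 49408, (10, 16))

    # split_size = 4 should yield chunks of 4, 4 and 2 samples, with the
    # positional and keyword tensors split consistently along dim 0
    for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(images, text = text, split_size = 4):
        (chunked_images,) = chunked_args
        chunked_text = chunked_kwargs['text']
        print(chunk_size, chunked_images.shape[0], chunked_text.shape[0])
        # -> 4 4 4
        # -> 4 4 4
        # -> 2 2 2

Weighting each chunk's loss by chunk_size / batch_size in the trainer forwards above keeps the accumulated gradient equivalent to a single full-batch step when the underlying loss is a per-sample mean.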


@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.24',
+  version = '0.2.26',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',