take care of gradient accumulation automatically for researchers, by passing in a max_batch_size on the decoder or diffusion prior trainer forward

This commit is contained in:
Phil Wang
2022-05-14 17:04:09 -07:00
parent b494ed81d4
commit b0cd5f24b6
3 changed files with 85 additions and 16 deletions

View File

@@ -732,8 +732,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# decoder (with unet)
@@ -774,7 +774,12 @@ decoder_trainer = DecoderTrainer(
)
for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss = decoder_trainer(
images,
text = text,
unet_number = unet_number, # which unet to train on
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
)
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average

View File

@@ -1,6 +1,8 @@
import time
import copy
from math import ceil
from functools import partial
from collections.abc import Iterable
import torch
from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@@ -40,6 +45,46 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# gradient accumulation functions
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
batch_size = len(x)
split_size = default(split_size, batch_size)
chunk_size = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
all_args = (x, *args, *kwargs.values())
len_all_args = len(all_args)
split_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_index], chunked_all_args[split_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
yield chunk_size, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
@@ -208,15 +253,25 @@ class DiffusionPriorTrainer(nn.Module):
def forward(
self,
x,
*args,
divisor = 1,
max_batch_size = None,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
scaled_loss = self.scaler.scale(loss / divisor)
scaled_loss.backward()
return loss.item()
batch_size = x.shape[0]
total_samples = 0
total_loss = 0.
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
total_loss += loss.item() * chunk_size
total_samples += chunk_size
self.scaler.scale(loss * (chunk_size / batch_size)).backward()
return total_loss / total_samples
# decoder trainer
@@ -327,11 +382,20 @@ class DecoderTrainer(nn.Module):
x,
*,
unet_number,
divisor = 1,
max_batch_size = None,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs)
scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
scaled_loss.backward()
return loss.item()
batch_size = x.shape[0]
total_samples = 0
total_loss = 0.
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
total_loss += loss.item() * chunk_size
total_samples += chunk_size
self.scale(loss * (chunk_size / batch_size), unet_number = unet_number).backward()
return total_loss / total_samples

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.24',
version = '0.2.26',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',