def cast_torch_tensor(fn):
    """Decorate a model method so array inputs arrive as torch tensors.

    Every positional and keyword argument that is a ``numpy.ndarray`` is
    converted with ``torch.from_numpy``. Unless disabled, every tensor
    argument is then moved to the target device. All other values
    (including ``None``) are passed through untouched.

    The wrapped method accepts two extra keyword arguments, which are
    consumed here and never forwarded to ``fn``:

    ``_device``
        Device to move tensors to. When omitted (or ``None``) the device of
        the first parameter yielded by ``model.parameters()`` is used.
    ``_cast_device``
        When falsy, tensors are left on whatever device they already live
        on. Defaults to ``True``.

    If no ``_device`` is given and the model yields no parameters, there is
    no device to infer, so tensors are left where they are.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', None)
        cast_device = kwargs.pop('_cast_device', True)

        # Look the model's device up lazily: only when a move is requested
        # and the caller did not supply one. An eager lookup would fail on a
        # parameterless model even when the result is never used.
        if cast_device and device is None:
            first_param = next(model.parameters(), None)
            if first_param is not None:
                device = first_param.device

        def cast(value):
            # isinstance() is already False for None, so no separate
            # existence check is required.
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            if cast_device and device is not None and isinstance(value, torch.Tensor):
                value = value.to(device)
            return value

        args = tuple(cast(arg) for arg in args)
        kwargs = {key: cast(value) for key, value in kwargs.items()}

        return fn(model, *args, **kwargs)
    return inner
self.ema_diffusion_prior.ema_model.sample(*args, **kwargs) @@ -265,6 +292,7 @@ class DiffusionPriorTrainer(nn.Module): def sample_batch_size(self, *args, **kwargs): return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs) + @cast_torch_tensor def forward( self, *args, @@ -377,6 +405,7 @@ class DecoderTrainer(nn.Module): self.step += 1 @torch.no_grad() + @cast_torch_tensor def sample(self, *args, **kwargs): if self.use_ema: trainable_unets = self.decoder.unets @@ -393,6 +422,7 @@ class DecoderTrainer(nn.Module): return output + @cast_torch_tensor def forward( self, *args, diff --git a/setup.py b/setup.py index b6203e4..f2af0ba 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.32', + version = '0.2.34', license='MIT', description = 'DALL-E 2', author = 'Phil Wang', @@ -30,6 +30,7 @@ setup( 'einops-exts>=0.0.3', 'embedding-reader', 'kornia>=0.5.4', + 'numpy', 'pillow', 'resize-right>=0.0.2', 'rotary-embedding-torch',