Compare commits

...

2 Commits

Author SHA1 Message Date
Phil Wang
99778e12de trainer classes now takes care of auto-casting numpy to torch tensors, and setting correct device based on model parameter devices 2022-05-15 15:25:45 -07:00
Phil Wang
0f0011caf0 todo 2022-05-15 14:28:35 -07:00
3 changed files with 34 additions and 2 deletions

View File

@@ -1013,6 +1013,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
## Citations

View File

@@ -1,7 +1,7 @@
import time
import copy
from math import ceil
from functools import partial
from functools import partial, wraps
from collections.abc import Iterable
import torch
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
import numpy as np
# helper functions
def exists(val):
@@ -45,6 +47,29 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# decorators
def cast_torch_tensor(fn):
@wraps(fn)
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
split_kwargs_index = len(all_args) - len(kwargs_keys)
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
out = fn(model, *args, **kwargs)
return out
return inner
# gradient accumulation functions
def split_iterable(it, split_size):
@@ -254,10 +279,12 @@ class DiffusionPriorTrainer(nn.Module):
self.step += 1
@torch.inference_mode()
@cast_torch_tensor
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
@cast_torch_tensor
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@@ -265,6 +292,7 @@ class DiffusionPriorTrainer(nn.Module):
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
@cast_torch_tensor
def forward(
self,
*args,
@@ -377,6 +405,7 @@ class DecoderTrainer(nn.Module):
self.step += 1
@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
@@ -393,6 +422,7 @@ class DecoderTrainer(nn.Module):
return output
@cast_torch_tensor
def forward(
self,
*args,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.32',
version = '0.2.34',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',