Compare commits

2 Commits

Author SHA1 Message Date
Phil Wang  b22ccd9dd0  trainer classes now takes care of auto-casting numpy to torch tensors, and setting correct device based on model parameter devices  2022-05-15 15:21:49 -07:00
Phil Wang  0f0011caf0  todo  2022-05-15 14:28:35 -07:00
3 changed files with 29 additions and 2 deletions

View File

@@ -1013,6 +1013,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] decoder needs one day worth of refactor for tech debt
 - [ ] allow for unet to be able to condition non-cross attention style as well
 - [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
+- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
 ## Citations
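The to-do added above asks for the exponential moving averaged models and the step counter to be checkpointed alongside the trainers. As a rough illustration only, helpers for the prior trainer could take the following shape; `ema_diffusion_prior.ema_model` and `step` appear in the trainer diff below, while `diffusion_prior` and the helper names are assumptions, not the library's API:

import torch

# hypothetical checkpoint helpers sketching the to-do above -- not part of the library
def save_prior_trainer(trainer, path):
    torch.save({
        'model': trainer.diffusion_prior.state_dict(),          # attribute name assumed
        'ema_model': trainer.ema_diffusion_prior.ema_model.state_dict(),
        'step': trainer.step,
    }, path)

def load_prior_trainer(trainer, path):
    ckpt = torch.load(path, map_location = 'cpu')
    trainer.diffusion_prior.load_state_dict(ckpt['model'])
    trainer.ema_diffusion_prior.ema_model.load_state_dict(ckpt['ema_model'])
    trainer.step = ckpt['step']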

View File

@@ -1,7 +1,7 @@
 import time
 import copy
 from math import ceil
-from functools import partial
+from functools import partial, wraps
 from collections.abc import Iterable
 import torch
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
 from dalle2_pytorch.optimizer import get_optimizer
+import numpy as np
 # helper functions
 def exists(val):
@@ -45,6 +47,24 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs
+# decorators
+def cast_torch_tensor(fn):
+    @wraps(fn)
+    def inner(model, *args, **kwargs):
+        device = next(model.parameters()).device
+        kwargs_keys = kwargs.keys()
+        all_args = (*args, *kwargs.values())
+        split_kwargs_index = len(all_args) - len(kwargs_keys)
+        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
+        all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
+        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
+        out = fn(model, *args, **kwargs)
+        return out
+    return inner
 # gradient accumulation functions
 def split_iterable(it, split_size):
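The decorator added in this hunk implements what commit b22ccd9dd0 describes: it walks the positional and keyword arguments of the wrapped method, converts any numpy arrays into torch tensors, and moves every tensor onto the device of the wrapped model's parameters before the method body runs. A minimal standalone check of that behavior (the import path and the toy `Dummy` module are assumptions for illustration, not part of the library):

import numpy as np
import torch
from torch import nn

from dalle2_pytorch.trainer import cast_torch_tensor  # module path assumed

class Dummy(nn.Module):
    # toy module used only to exercise the decorator
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @cast_torch_tensor
    def forward(self, x):
        # by this point a numpy input has become a torch tensor on proj's device
        assert isinstance(x, torch.Tensor) and x.device == self.proj.weight.device
        return self.proj(x)

model = Dummy()
out = model(np.random.randn(2, 4).astype(np.float32))  # numpy in, torch tensor out
print(out.shape)  # torch.Size([2, 4])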
@@ -254,10 +274,12 @@ class DiffusionPriorTrainer(nn.Module):
         self.step += 1
     @torch.inference_mode()
+    @cast_torch_tensor
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
     @torch.inference_mode()
+    @cast_torch_tensor
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@@ -265,6 +287,7 @@ class DiffusionPriorTrainer(nn.Module):
     def sample_batch_size(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
+    @cast_torch_tensor
     def forward(
         self,
         *args,
@@ -377,6 +400,7 @@ class DecoderTrainer(nn.Module):
         self.step += 1
     @torch.no_grad()
+    @cast_torch_tensor
     def sample(self, *args, **kwargs):
         if self.use_ema:
             trainable_unets = self.decoder.unets
@@ -393,6 +417,7 @@ class DecoderTrainer(nn.Module):
         return output
+    @cast_torch_tensor
     def forward(
         self,
         *args,
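The net effect for callers of both trainer classes: `sample` and `forward` now accept numpy arrays directly, and whatever is passed arrives inside the model as torch tensors on the right device. A hedged usage sketch -- the trainer object is assumed to be already constructed, and the argument below is only a placeholder, since the wrapped methods forward *args/**kwargs unchanged to the underlying EMA model:

import numpy as np

# assume `prior_trainer` is an already constructed DiffusionPriorTrainer
# (constructor arguments omitted; see the repository README)
conditioning = np.random.randn(4, 512).astype(np.float32)  # illustrative numpy input

# sample() is wrapped by @torch.inference_mode() and @cast_torch_tensor, so the numpy
# array is cast to a torch tensor and moved onto the EMA prior's device automatically;
# which arguments sample() actually expects is unchanged by this commit
out = prior_trainer.sample(conditioning)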

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.32',
+  version = '0.2.33',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
     'einops-exts>=0.0.3',
     'embedding-reader',
     'kornia>=0.5.4',
+    'numpy',
     'pillow',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',