mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
trainer classes now take care of auto-casting numpy arrays to torch tensors, and setting the correct device based on the model's parameter devices
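For reference, the behavior described above can be sketched in isolation. The exists helper and the cast_torch_tensor decorator below are copied from the diff in this commit; the ToyTrainer module and the call at the end are illustrative stand-ins, not part of the repository.

import numpy as np
import torch
from torch import nn
from functools import wraps

# helper and decorator copied from the diff below
def exists(val):
    return val is not None

def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        # device defaults to wherever the model's parameters live
        device = kwargs.pop('_device', next(model.parameters()).device)
        cast_device = kwargs.pop('_cast_device', True)

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        # numpy arrays become torch tensors
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            # tensors are moved to the model's device
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        return fn(model, *args, **kwargs)
    return inner

# purely illustrative toy module (hypothetical, not from the repository)
class ToyTrainer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)

    @cast_torch_tensor
    def forward(self, embed):
        return self.proj(embed)

trainer = ToyTrainer()
out = trainer(np.random.randn(4, 512).astype(np.float32))  # numpy in, torch.Tensor out
assert isinstance(out, torch.Tensor) and out.device == next(trainer.parameters()).device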
@@ -1,7 +1,7 @@
 import time
 import copy
 from math import ceil
-from functools import partial
+from functools import partial, wraps
 from collections.abc import Iterable

 import torch
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
 from dalle2_pytorch.optimizer import get_optimizer

+import numpy as np
+
 # helper functions

 def exists(val):
@@ -45,6 +47,29 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs

+# decorators
+
+def cast_torch_tensor(fn):
+    @wraps(fn)
+    def inner(model, *args, **kwargs):
+        device = kwargs.pop('_device', next(model.parameters()).device)
+        cast_device = kwargs.pop('_cast_device', True)
+
+        kwargs_keys = kwargs.keys()
+        all_args = (*args, *kwargs.values())
+        split_kwargs_index = len(all_args) - len(kwargs_keys)
+        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
+
+        if cast_device:
+            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+
+        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
+        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
+
+        out = fn(model, *args, **kwargs)
+        return out
+    return inner
+
 # gradient accumulation functions

 def split_iterable(it, split_size):
@@ -254,10 +279,12 @@ class DiffusionPriorTrainer(nn.Module):
         self.step += 1

     @torch.inference_mode()
+    @cast_torch_tensor
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)

     @torch.inference_mode()
+    @cast_torch_tensor
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@@ -265,6 +292,7 @@ class DiffusionPriorTrainer(nn.Module):
     def sample_batch_size(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)

+    @cast_torch_tensor
     def forward(
         self,
         *args,
@@ -377,6 +405,7 @@ class DecoderTrainer(nn.Module):
         self.step += 1

     @torch.no_grad()
+    @cast_torch_tensor
     def sample(self, *args, **kwargs):
         if self.use_ema:
             trainable_unets = self.decoder.unets
@@ -393,6 +422,7 @@ class DecoderTrainer(nn.Module):

         return output

+    @cast_torch_tensor
     def forward(
         self,
         *args,
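Note that cast_torch_tensor also reserves two keyword arguments, which it pops before calling the wrapped method: _device overrides the device otherwise inferred from the model's parameters, and _cast_device (default True) disables the device move entirely when set to False; numpy arrays are still converted to tensors in that case.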
setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.32',
+  version = '0.2.34',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
     'einops-exts>=0.0.3',
     'embedding-reader',
     'kornia>=0.5.4',
+    'numpy',
     'pillow',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',