Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
trainer classes now take care of auto-casting numpy arrays to torch tensors, and setting the correct device based on model parameter devices
@@ -1,7 +1,7 @@
 import time
 import copy
 from math import ceil
-from functools import partial
+from functools import partial, wraps
 from collections.abc import Iterable
 
 import torch
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
 from dalle2_pytorch.optimizer import get_optimizer
 
+import numpy as np
+
 # helper functions
 
 def exists(val):
@@ -45,6 +47,24 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs
 
+# decorators
+
+def cast_torch_tensor(fn):
+    @wraps(fn)
+    def inner(model, *args, **kwargs):
+        device = next(model.parameters()).device
+        kwargs_keys = kwargs.keys()
+        all_args = (*args, *kwargs.values())
+        split_kwargs_index = len(all_args) - len(kwargs_keys)
+        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
+        all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
+        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
+
+        out = fn(model, *args, **kwargs)
+        return out
+    return inner
+
 # gradient accumulation functions
 
 def split_iterable(it, split_size):
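For reference, a minimal runnable sketch of what the new cast_torch_tensor decorator does when applied to a model method. The decorator and the exists helper are copied from the hunk above; ToyModel and the example call are purely illustrative and not part of the library.

# minimal sketch -- cast_torch_tensor and exists are taken from the hunk above,
# ToyModel is an illustrative stand-in for one of the trainer-wrapped models
from functools import wraps

import numpy as np
import torch
from torch import nn

def exists(val):
    return val is not None

def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = next(model.parameters()).device                      # target device = wherever the model's weights live
        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())                          # fold keyword values in with the positional args
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        # numpy arrays become torch tensors, then every tensor is moved to the model's device;
        # anything else (None, ints, strings) passes through untouched
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
        all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
        return fn(model, *args, **kwargs)
    return inner

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @cast_torch_tensor
    def sample(self, x, noise = None):
        # by the time the body runs, x is a torch.Tensor on the same device as self.proj
        return self.proj(x) if not exists(noise) else self.proj(x) + noise

model = ToyModel()
out = model.sample(np.random.randn(2, 8).astype(np.float32))          # numpy in, no manual torch.from_numpy / .to(device)
print(out.shape)                                                       # torch.Size([2, 8])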
@@ -254,10 +274,12 @@ class DiffusionPriorTrainer(nn.Module):
         self.step += 1
 
     @torch.inference_mode()
+    @cast_torch_tensor
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
 
     @torch.inference_mode()
+    @cast_torch_tensor
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
 
@@ -265,6 +287,7 @@ class DiffusionPriorTrainer(nn.Module):
     def sample_batch_size(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
 
+    @cast_torch_tensor
     def forward(
         self,
         *args,
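With these decorators in place, the prior trainer's p_sample_loop and sample (both under @torch.inference_mode()) as well as forward can be fed numpy arrays directly and will run on whichever device the EMA prior's parameters live on. Continuing the illustrative sketch above (same cast_torch_tensor and ToyModel), keyword arguments are converted as well, because their values are folded into all_args and split back out by position:

# keyword numpy arrays are cast too; None (and any non-array value) is left alone by the exists(...) guard
cond = np.ones((2, 8), dtype = np.float32)
out = model.sample(np.random.randn(2, 8).astype(np.float32), noise = cond)   # both arguments end up as tensors on the model's device
out = model.sample(np.random.randn(2, 8).astype(np.float32), noise = None)   # still fine, nothing to convert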
@@ -377,6 +400,7 @@ class DecoderTrainer(nn.Module):
         self.step += 1
 
     @torch.no_grad()
+    @cast_torch_tensor
     def sample(self, *args, **kwargs):
         if self.use_ema:
             trainable_unets = self.decoder.unets
@@ -393,6 +417,7 @@ class DecoderTrainer(nn.Module):
 
         return output
 
+    @cast_torch_tensor
     def forward(
         self,
         *args,
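The same treatment is applied to DecoderTrainer.sample and DecoderTrainer.forward. One detail worth noting: the no-grad decorator sits above @cast_torch_tensor, and decorators apply bottom-up, so the casting wrapper is wrapped in turn by torch.no_grad and the whole call, conversion included, runs with autograd disabled. A tiny continuation of the sketch above (cast_torch_tensor and imports as before; ToySampler is illustrative):

class ToySampler(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.no_grad()                 # applied last, wraps the casting wrapper
    @cast_torch_tensor               # applied first, wraps sample itself
    def sample(self, x):
        return self.proj(x)

sampler = ToySampler()
out = sampler.sample(np.random.randn(2, 8).astype(np.float32))
print(out.requires_grad)             # False -- sampling stays gradient-free, numpy input or not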
setup.py (3 changed lines)
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.32',
+  version = '0.2.33',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
     'einops-exts>=0.0.3',
     'embedding-reader',
     'kornia>=0.5.4',
+    'numpy',
     'pillow',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',