Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-14 20:04:21 +01:00)

# Compare commits

5 commits:

- 164d9be444
- 5562ec6be2
- 89ff04cfe2
- f4016f6302
- 1212f7058d
````diff
@@ -14,6 +14,12 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord
 There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
 
+## Status
+
+- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
+
+- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers
+
 ## Install
 
 ```bash
````
```diff
@@ -61,6 +61,9 @@ def default(val, d):
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
 
+def module_device(module):
+    return next(module.parameters()).device
+
 @contextmanager
 def null_context(*args, **kwargs):
     yield
```
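The new `module_device` helper factors out the `next(module.parameters()).device` idiom that previously appeared inline; the hunks below swap it in at two call sites. A minimal standalone sketch of its behavior (the `nn.Linear` example is illustrative, not part of the diff):

```python
from torch import nn

def module_device(module):
    # a module has no single canonical device; use the device of its first parameter
    return next(module.parameters()).device

layer = nn.Linear(4, 4)
print(module_device(layer))  # cpu, or e.g. cuda:0 after layer.cuda()
```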
```diff
@@ -1817,7 +1820,7 @@ class Decoder(BaseGaussianDiffusion):
 
         self.cuda()
 
-        devices = [next(unet.parameters()).device for unet in self.unets]
+        devices = [module_device(unet) for unet in self.unets]
         self.unets.cpu()
         unet.cuda()
 
```
```diff
@@ -1867,13 +1870,14 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
         device = self.betas.device
 
         b = shape[0]
         img = torch.randn(shape, device = device)
 
-        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+        if not is_latent_diffusion:
+            lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
```
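For context on the guarded line: `maybe(fn)` applies `fn` only when its argument is not `None`, and `normalize_neg_one_to_one` rescales `[0, 1]` pixels into the `[-1, 1]` range the DDPM operates in. A sketch of both helpers, reconstructed to match their usage here (the exact bodies live elsewhere in the repository):

```python
def exists(val):
    return val is not None

def maybe(fn):
    # lift fn so that a None value passes through untouched
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

def normalize_neg_one_to_one(img):
    # [0, 1] pixels -> [-1, 1], the range the diffusion model is trained on
    return img * 2 - 1
```

Latent-diffusion inputs are VQGAN latents rather than pixels, hence the new `is_latent_diffusion` guard around the rescaling here and in `p_losses` below.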
```diff
@@ -1893,13 +1897,14 @@ class Decoder(BaseGaussianDiffusion):
         unnormalize_img = unnormalize_zero_to_one(img)
         return unnormalize_img
 
-    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
+    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
 
-        x_start = normalize_neg_one_to_one(x_start)
-        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+        if not is_latent_diffusion:
+            x_start = normalize_neg_one_to_one(x_start)
+            lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
 
         # get x_t
 
```
```diff
@@ -1965,6 +1970,8 @@ class Decoder(BaseGaussianDiffusion):
         self,
         image_embed = None,
         text = None,
+        text_mask = None,
+        text_encodings = None,
         batch_size = 1,
         cond_scale = 1.,
         stop_at_unet_number = None
```
```diff
@@ -1974,8 +1981,8 @@ class Decoder(BaseGaussianDiffusion):
         if not self.unconditional:
             batch_size = image_embed.shape[0]
 
-        text_encodings = text_mask = None
-        if exists(text):
+        if exists(text) and not exists(text_encodings) and not self.unconditional:
+            assert exists(self.clip)
             _, text_encodings, text_mask = self.clip.embed_text(text)
 
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
```
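With the new keyword arguments, `Decoder.sample` can take precomputed text conditioning instead of raw tokens; raw `text` is only run through CLIP when encodings are absent. A hedged usage sketch (`decoder`, `clip`, `image_embed`, and `text_tokens` are assumed to be already-constructed objects, not names from this diff):

```python
# option 1: let the decoder embed tokenized text through its own CLIP
images = decoder.sample(image_embed = image_embed, text = text_tokens)

# option 2: pass precomputed conditioning, bypassing CLIP at sample time
_, text_encodings, text_mask = clip.embed_text(text_tokens)
images = decoder.sample(image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask)
```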
```diff
@@ -2011,7 +2018,8 @@ class Decoder(BaseGaussianDiffusion):
                 predict_x_start = predict_x_start,
                 learned_variance = learned_variance,
                 clip_denoised = not is_latent_diffusion,
-                lowres_cond_img = lowres_cond_img
+                lowres_cond_img = lowres_cond_img,
+                is_latent_diffusion = is_latent_diffusion
             )
 
             img = vae.decode(img)
```
```diff
@@ -2027,6 +2035,7 @@ class Decoder(BaseGaussianDiffusion):
         text = None,
         image_embed = None,
         text_encodings = None,
+        text_mask = None,
         unet_number = None
     ):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
```
```diff
@@ -2051,7 +2060,6 @@ class Decoder(BaseGaussianDiffusion):
             assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
             image_embed, _ = self.clip.embed_image(image)
 
-        text_encodings = text_mask = None
         if exists(text) and not exists(text_encodings) and not self.unconditional:
             assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
             _, text_encodings, text_mask = self.clip.embed_text(text)
```
```diff
@@ -2070,12 +2078,14 @@ class Decoder(BaseGaussianDiffusion):
             image = aug(image)
             lowres_cond_img = aug(lowres_cond_img, params = aug._params)
 
+        is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
+
         vae.eval()
         with torch.no_grad():
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
 
 # main class
 
```
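`is_latent_diffusion` is inferred from the VAE attached to the unet: `NullVQGanVAE` is the pass-through placeholder used for pixel-space unets, so only a real VQGAN stage flips the flag. An illustrative sketch of the placeholder's role (a stand-in, not the repository's implementation):

```python
from torch import nn

class NullVQGanVAE(nn.Module):
    # identity placeholder: with no VQGAN, the unet diffuses raw pixels,
    # which do get normalized to [-1, 1] as before
    def encode(self, x):
        return x

    def decode(self, x):
        return x
```

A real VQGAN's `encode` returns latents that are not confined to `[0, 1]`, which is why the pixel rescaling in `p_losses` and `p_sample_loop` must be skipped when the flag is set.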
```diff
@@ -2107,7 +2117,7 @@ class DALLE2(nn.Module):
         prior_cond_scale = 1.,
         return_pil_images = False
     ):
-        device = next(self.parameters()).device
+        device = module_device(self)
         one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
 
         if isinstance(text, str) or is_list_str(text):
```
## New file: dalle2_pytorch/dataloaders/simple_image_only_dataloader.py (59 lines)

```python
from pathlib import Path

import torch
from torch.utils import data
from torchvision import transforms, utils

from PIL import Image

# helpers functions

def cycle(dl):
    while True:
        for data in dl:
            yield data

# dataset and dataloader

class Dataset(data.Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = True,
    pin_memory = True
):
    ds = Dataset(folder, image_size)
    dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)

    if cycle_dl:
        dl = cycle(dl)
    return dl
```
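A hedged usage sketch for the new dataloader; `./images` is a placeholder path. With `cycle_dl = True` (the default) the loader is wrapped in an endless generator, which suits step-based rather than epoch-based training loops:

```python
from dalle2_pytorch.dataloaders.simple_image_only_dataloader import get_images_dataloader

dl = get_images_dataloader('./images', batch_size = 16, image_size = 256)

images = next(dl)  # for RGB inputs: a (16, 3, 256, 256) tensor with values in [0, 1]
```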
```diff
@@ -47,6 +47,14 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs
 
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
 # decorators
 
 def cast_torch_tensor(fn):
```
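`num_to_groups` splits a total count into `divisor`-sized groups plus an optional remainder group; the chunked-sampling decorator further down uses it to break one large unconditional batch into sub-batches:

```python
num_to_groups(25, 10)  # [10, 10, 5]
num_to_groups(20, 10)  # [10, 10]
num_to_groups(8, 10)   # [8]
```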
```diff
@@ -179,8 +187,8 @@ class EMA(nn.Module):
         self.online_model = model
         self.ema_model = copy.deepcopy(model)
 
-        self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
-        self.update_every = update_every
+        self.update_every = update_every
+        self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
 
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0.]))
```
```diff
@@ -189,14 +197,21 @@ class EMA(nn.Module):
         device = self.initted.device
         self.ema_model.to(device)
 
+    def copy_params_from_model_to_ema(self):
+        self.ema_model.state_dict(self.online_model.state_dict())
+
     def update(self):
         self.step += 1
 
-        if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
+        if (self.step % self.update_every) != 0:
+            return
+
+        if self.step <= self.update_after_step:
+            self.copy_params_from_model_to_ema()
             return
 
         if not self.initted:
-            self.ema_model.state_dict(self.online_model.state_dict())
+            self.copy_params_from_model_to_ema()
             self.initted.data.copy_(torch.Tensor([True]))
 
         self.update_moving_average(self.ema_model, self.online_model)
```
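The rewrite changes warmup behavior: before the `update_after_step` threshold is reached, the EMA weights are now kept synchronized with the online model rather than left at their initial copy. One caveat with the extracted helper: PyTorch's `Module.state_dict(destination)` writes the module's *own* parameters into `destination` and does not modify the module, so as written `copy_params_from_model_to_ema` does not actually overwrite the EMA weights. A sketch of the copy it appears to intend (an observation about PyTorch semantics, not a change present in this diff):

```python
def copy_params_from_model_to_ema(self):
    # load_state_dict mutates ema_model in place;
    # state_dict(destination) only reads from it
    self.ema_model.load_state_dict(self.online_model.state_dict())
```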
```diff
@@ -315,6 +330,22 @@ class DiffusionPriorTrainer(nn.Module):
 
 # decoder trainer
 
+def decoder_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        if self.decoder.unconditional:
+            batch_size = kwargs.get('batch_size')
+            batch_sizes = num_to_groups(batch_size, max_batch_size)
+            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
+        else:
+            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
```
```diff
@@ -404,15 +435,17 @@ class DecoderTrainer(nn.Module):
 
     @torch.no_grad()
     @cast_torch_tensor
+    @decoder_sample_in_chunks
     def sample(self, *args, **kwargs):
-        if self.use_ema:
-            trainable_unets = self.decoder.unets
-            self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+        if kwargs.pop('use_non_ema', False) or not self.use_ema:
+            return self.decoder.sample(*args, **kwargs)
+
+        trainable_unets = self.decoder.unets
+        self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
 
         output = self.decoder.sample(*args, **kwargs)
 
-        if self.use_ema:
-            self.decoder.unets = trainable_unets # restore original training unets
+        self.decoder.unets = trainable_unets # restore original training unets
 
         # cast the ema_model unets back to original device
         for ema in self.ema_unets:
```
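Together, `@decoder_sample_in_chunks` and the `use_non_ema` escape hatch add two sampling knobs. A hedged usage sketch, assuming `trainer` is a constructed `DecoderTrainer` and `image_embed` a batch of CLIP image embeddings:

```python
# split sampling into sub-batches of at most 16 to bound peak memory;
# the decorator concatenates the chunks back along the batch dimension
images = trainer.sample(image_embed = image_embed, max_batch_size = 16)

# sample with the online (training) unets instead of their EMA copies
images = trainer.sample(image_embed = image_embed, use_non_ema = True)
```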