mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
830afd3c15 | ||
|
|
8f93729d19 | ||
|
|
cd5f2c1de4 | ||
|
|
85ed77d512 | ||
|
|
fd53fa17db | ||
|
|
3676ef4d49 | ||
|
|
28e944f328 | ||
|
|
14e63a3f67 | ||
|
|
09e9eaa5a6 | ||
|
|
e6d752cf4a |
@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
$ pyhon train_diffusion_prior.py
|
||||
$ python train_diffusion_prior.py
|
||||
```
|
||||
|
||||
The most significant parameters for the script are as follows:
|
||||
@@ -967,7 +967,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [ ] train on a toy task, offer in colab
|
||||
@@ -980,6 +980,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -264,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
|
||||
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
@@ -272,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(image_embed.float(), None)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
@@ -706,7 +706,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
|
||||
@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
||||
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_encodings,
|
||||
text_embed,
|
||||
time_embed,
|
||||
image_embed,
|
||||
learned_queries
|
||||
), dim = -2)
|
||||
|
||||
@@ -805,6 +806,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False,
|
||||
training_clamp_l2norm = False,
|
||||
init_image_embed_l2norm = False,
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
clip_adapter_overrides = dict()
|
||||
):
|
||||
@@ -842,6 +845,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
self.init_image_embed_l2norm = init_image_embed_l2norm
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
@@ -876,11 +881,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
image_embed = torch.randn(shape, device=device)
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
return img
|
||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
|
||||
|
||||
return image_embed
|
||||
|
||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
@@ -894,6 +904,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
**text_cond
|
||||
)
|
||||
|
||||
if self.predict_x_start and self.training_clamp_l2norm:
|
||||
pred = l2norm(pred) * self.image_embed_scale
|
||||
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
|
||||
@@ -3,14 +3,15 @@ import copy
|
||||
from random import choice
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms as T
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
|
||||
import torchvision.transforms as T
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
amp = False
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
|
||||
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
|
||||
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
|
||||
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
|
||||
|
||||
self.amp = amp
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
self.discr_scaler = GradScaler(enabled = amp)
|
||||
|
||||
# create dataset
|
||||
|
||||
self.ds = ImageDataset(folder, image_size = image_size)
|
||||
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
|
||||
|
||||
self.scaler.scale(loss / self.grad_accum_every).backward()
|
||||
|
||||
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.optim.step()
|
||||
self.scaler.step(self.optim)
|
||||
self.scaler.update()
|
||||
self.optim.zero_grad()
|
||||
|
||||
|
||||
# update discriminator
|
||||
|
||||
if exists(self.vae.discr):
|
||||
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
|
||||
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
|
||||
|
||||
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.discr_optim.step()
|
||||
self.discr_scaler.step(self.discr_optim)
|
||||
self.discr_scaler.update()
|
||||
self.discr_optim.zero_grad()
|
||||
|
||||
# log
|
||||
|
||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.1.1',
|
||||
version = '0.1.8',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -46,28 +46,60 @@ def save_model(save_path, state_dict):
|
||||
print("====================================== Saving checkpoint ======================================")
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||
|
||||
def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device):
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
tstart = train_set_size+val_set_size
|
||||
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
|
||||
|
||||
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
|
||||
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
|
||||
# make a copy of the text embeddings for shuffling
|
||||
text_embed = torch.tensor(embt[0]).to(device)
|
||||
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
|
||||
test_text_cond = dict(text_embed = text_embed)
|
||||
text_embed_shuffled = text_embed.clone()
|
||||
|
||||
# roll the text embeddings to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
|
||||
|
||||
# prepare the text embedding
|
||||
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
|
||||
test_text_cond = dict(text_embed=text_embed)
|
||||
|
||||
# prepare image embeddings
|
||||
test_image_embeddings = torch.tensor(embi[0]).to(device)
|
||||
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
|
||||
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
|
||||
# calculate similarities
|
||||
original_similarity = cos(
|
||||
text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(
|
||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = cos(
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
|
||||
wandb.log(
|
||||
{"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
|
||||
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean(
|
||||
predicted_similarity)})
|
||||
wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
|
||||
unrelated_similarity)})
|
||||
|
||||
return np.mean(predicted_similarity - original_similarity)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user