Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-14 14:44:22 +01:00

Compare commits (7 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 830afd3c15 |  |
|  | 8f93729d19 |  |
|  | cd5f2c1de4 |  |
|  | 85ed77d512 |  |
|  | fd53fa17db |  |
|  | 3676ef4d49 |  |
|  | 28e944f328 |  |
@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
 ### Usage
 
 ```bash
-$ pyhon train_diffusion_prior.py
+$ python train_diffusion_prior.py
 ```
 
 The most significant parameters for the script are as follows:
@@ -264,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         text_embed = self.clip.encode_text(text)
         text_encodings = self.text_encodings
         del self.text_encodings
-        return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
+        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
 
     @torch.no_grad()
     def embed_image(self, image):
@@ -272,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image = resize_image_to(image, self.image_size)
         image = self.clip_normalize(unnormalize_img(image))
         image_embed = self.clip.encode_image(image)
-        return EmbeddedImage(image_embed.float(), None)
+        return EmbeddedImage(l2norm(image_embed.float()), None)
 
 # classifier free guidance functions
 
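Both adapter hunks route the raw CLIP embeddings through l2norm before returning them, so the prior always sees unit-length targets. A minimal sketch of what such an l2norm helper typically looks like (an assumption; the repository's own definition may differ):

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # normalize along the last dimension so every embedding has unit length
    return F.normalize(t, dim = -1)

embeds = torch.randn(4, 512)          # a batch of 4 hypothetical CLIP-sized embeddings
print(l2norm(embeds).norm(dim = -1))  # tensor([1., 1., 1., 1.])
```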
@@ -706,7 +706,7 @@ class DiffusionPriorNetwork(nn.Module):
         **kwargs
     ):
         super().__init__()
-        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
+        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
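The continuous-timestep branch now encodes the scalar timestep with a sinusoidal position embedding before the two-layer MLP, rather than a bare `Rearrange('b -> b 1')` into a width-1 MLP. A sketch of the usual sinusoidal formulation (the repository's SinusoidalPosEmb may differ in detail):

```python
import math
import torch
from torch import nn

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # t: (batch,) timesteps -> (batch, dim) of sine/cosine features at log-spaced frequencies
        half_dim = self.dim // 2
        freqs = torch.exp(torch.arange(half_dim, device = t.device) * -(math.log(10000) / (half_dim - 1)))
        angles = t[:, None].float() * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim = -1)

emb = SinusoidalPosEmb(128)
print(emb(torch.tensor([0., 0.5, 1.0])).shape)  # torch.Size([3, 128])
```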
@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
         # but let's just do it right
 
         if exists(mask):
-            mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+            mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.time_embeddings(diffusion_timesteps)
         time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
             text_encodings,
             text_embed,
             time_embed,
+            image_embed,
             learned_queries
         ), dim = -2)
 
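This hunk is the counterpart of the mask change above: the noised image_embed becomes one more token in the causal transformer's input, so the text mask now has to be padded with three trailing always-true positions instead of two. A tiny illustration with made-up shapes (not taken from the repository):

```python
import torch
import torch.nn.functional as F

batch, text_len = 2, 77
mask = torch.ones(batch, text_len, dtype = torch.bool)  # mask over the text encodings

# pad with True so the appended non-text tokens are always attendable
mask = F.pad(mask, (0, 3), value = True)
print(mask.shape)                                        # torch.Size([2, 80])
```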
@@ -806,6 +807,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False,
         training_clamp_l2norm = False,
+        init_image_embed_l2norm = False,
         image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
         clip_adapter_overrides = dict()
     ):
@@ -844,6 +846,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         # whether to force an l2norm, similar to clipping denoised, when sampling
         self.sampling_clamp_l2norm = sampling_clamp_l2norm
         self.training_clamp_l2norm = training_clamp_l2norm
+        self.init_image_embed_l2norm = init_image_embed_l2norm
 
     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
         pred = self.net(x, t, **text_cond)
@@ -878,11 +881,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
         device = self.betas.device
 
         b = shape[0]
-        img = torch.randn(shape, device=device)
+        image_embed = torch.randn(shape, device=device)
+
+        if self.init_image_embed_l2norm:
+            image_embed = l2norm(image_embed) * self.image_embed_scale
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
-        return img
+            times = torch.full((b,), i, device = device, dtype = torch.long)
+            image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
+
+        return image_embed
 
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
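With init_image_embed_l2norm enabled, sampling no longer starts from raw Gaussian noise: the initial image_embed is projected onto the unit sphere and rescaled by image_embed_scale, matching the scale of the l2-normed targets the prior trains against. A standalone sketch of that initialization (the scale value below is only an illustrative placeholder, not the model's actual default):

```python
import torch
import torch.nn.functional as F

batch, dim = 4, 512
image_embed_scale = dim ** 0.5            # placeholder; the model stores its own image_embed_scale

image_embed = torch.randn(batch, dim)     # same starting noise as before
image_embed = F.normalize(image_embed, dim = -1) * image_embed_scale

print(image_embed.norm(dim = -1))         # every row now has norm == image_embed_scale
```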
@@ -3,14 +3,15 @@ import copy
 from random import choice
 from pathlib import Path
 from shutil import rmtree
+from PIL import Image
 
 import torch
 from torch import nn
-from PIL import Image
-from torchvision.datasets import ImageFolder
-import torchvision.transforms as T
+from torch.cuda.amp import autocast, GradScaler
 from torch.utils.data import Dataset, DataLoader, random_split
 
+import torchvision.transforms as T
+from torchvision.datasets import ImageFolder
 from torchvision.utils import make_grid, save_image
 
 from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
         ema_update_after_step = 2000,
         ema_update_every = 10,
         apply_grad_penalty_every = 4,
+        amp = False
     ):
         super().__init__()
         assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
         self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
         self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
 
+        self.amp = amp
+        self.scaler = GradScaler(enabled = amp)
+        self.discr_scaler = GradScaler(enabled = amp)
+
         # create dataset
 
         self.ds = ImageDataset(folder, image_size = image_size)
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
             img = next(self.dl)
             img = img.to(device)
 
-            loss = self.vae(
-                img,
-                return_loss = True,
-                apply_grad_penalty = apply_grad_penalty
-            )
+            with autocast(enabled = self.amp):
+                loss = self.vae(
+                    img,
+                    return_loss = True,
+                    apply_grad_penalty = apply_grad_penalty
+                )
+
+            self.scaler.scale(loss / self.grad_accum_every).backward()
 
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
 
-            (loss / self.grad_accum_every).backward()
-        self.optim.step()
+        self.scaler.step(self.optim)
+        self.scaler.update()
         self.optim.zero_grad()
 
 
         # update discriminator
 
         if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
                 img = next(self.dl)
                 img = img.to(device)
 
-                loss = self.vae(img, return_discr_loss = True)
+                with autocast(enabled = self.amp):
+                    loss = self.vae(img, return_discr_loss = True)
+
+                self.discr_scaler.scale(loss / self.grad_accum_every).backward()
 
                 accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
 
-                (loss / self.grad_accum_every).backward()
-            self.discr_optim.step()
+            self.discr_scaler.step(self.discr_optim)
+            self.discr_scaler.update()
             self.discr_optim.zero_grad()
 
         # log
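These two trainer hunks switch the VQGAN-VAE training step to standard PyTorch automatic mixed precision: the forward pass runs under autocast, the loss is scaled before backward, and the GradScaler drives the optimizer step. A minimal, self-contained sketch of the pattern (generic model and optimizer names, assumes a CUDA device):

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(512, 512).cuda()
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)
scaler = GradScaler(enabled = True)

x = torch.randn(8, 512).cuda()

with autocast(enabled = True):
    loss = model(x).pow(2).mean()   # forward pass in mixed precision

scaler.scale(loss).backward()       # scale the loss to avoid fp16 gradient underflow
scaler.step(optim)                  # unscale gradients, then step the optimizer
scaler.update()                     # adjust the scale factor for the next iteration
optim.zero_grad()
```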
setup.py (2 changed lines)
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.1.2',
+  version = '0.1.8',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -46,28 +46,60 @@ def save_model(save_path, state_dict):
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
 
-def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
+
+def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device):
     cos = nn.CosineSimilarity(dim=1, eps=1e-6)
 
     tstart = train_set_size+val_set_size
     tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
 
-    for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
+    for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
+        # make a copy of the text embeddings for shuffling
         text_embed = torch.tensor(embt[0]).to(device)
-        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
-        test_text_cond = dict(text_embed = text_embed)
+        text_embed_shuffled = text_embed.clone()
+
+        # roll the text embeddings to simulate "unrelated" captions
+        rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
+        text_embed_shuffled = text_embed_shuffled[rolled_idx]
+        text_embed_shuffled = text_embed_shuffled / \
+            text_embed_shuffled.norm(dim=1, keepdim=True)
+        test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
+
+        # prepare the text embedding
+        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
+        test_text_cond = dict(text_embed=text_embed)
 
+        # prepare image embeddings
         test_image_embeddings = torch.tensor(embi[0]).to(device)
-        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
+        test_image_embeddings = test_image_embeddings / \
+            test_image_embeddings.norm(dim=1, keepdim=True)
 
-        predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
-        predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
+        # predict on the unshuffled text embeddings
+        predicted_image_embeddings = diffusion_prior.p_sample_loop(
+            (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
+        predicted_image_embeddings = predicted_image_embeddings / \
+            predicted_image_embeddings.norm(dim=1, keepdim=True)
 
-        original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
-        predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
+        # predict on the shuffled embeddings
+        predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
+            (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
+        predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
+            predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
 
-        wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
-        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
+        # calculate similarities
+        original_similarity = cos(
+            text_embed, test_image_embeddings).cpu().numpy()
+        predicted_similarity = cos(
+            text_embed, predicted_image_embeddings).cpu().numpy()
+        unrelated_similarity = cos(
+            text_embed, predicted_unrelated_embeddings).cpu().numpy()
+
+        wandb.log(
+            {"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean(
+            predicted_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
+            unrelated_similarity)})
 
         return np.mean(predicted_similarity - original_similarity)
 
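The rolled-index baseline gives the report a built-in sanity check: predictions conditioned on mismatched captions should score visibly lower cosine similarity than predictions conditioned on the matching captions. A small standalone illustration of the shuffling trick (hypothetical tensors, not the script's embedding readers):

```python
import torch

NUM_TEST_EMBEDDINGS = 4
text_embed = torch.randn(NUM_TEST_EMBEDDINGS, 768)

# shift every row by one so each image is paired with an "unrelated" caption
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
text_embed_shuffled = text_embed[rolled_idx]

print(rolled_idx)  # tensor([3, 0, 1, 2])
```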