Compare commits

...

5 Commits
0.1.2 ... 0.1.7

Author SHA1 Message Date
z
cd5f2c1de4 simulate unrelated captions as a training metric (#66)
* add unrelated embedding metric

* change to torch.roll

Co-authored-by: nousr <z@localhost.com>
Co-authored-by: nousr <>
2022-05-07 05:34:59 -07:00
Phil Wang
85ed77d512 fix a potentially huge bug thanks to @CiaoHe https://github.com/lucidrains/DALLE2-pytorch/issues/71 2022-05-07 05:05:54 -07:00
Piero Rolando
fd53fa17db Fix a typo in README (#70)
Change "pyhon" for "python" (correct)
2022-05-06 16:53:36 -07:00
Phil Wang
3676ef4d49 make sure vqgan-vae trainer supports mixed precision 2022-05-06 10:44:16 -07:00
Phil Wang
28e944f328 make sure openai clip adapter outputs l2normed embeddings 2022-05-06 10:12:03 -07:00
5 changed files with 77 additions and 33 deletions

View File

@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
 ### Usage
 
 ```bash
-$ pyhon train_diffusion_prior.py
+$ python train_diffusion_prior.py
 ```
 
 The most significant parameters for the script are as follows:

View File

@@ -264,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         text_embed = self.clip.encode_text(text)
         text_encodings = self.text_encodings
         del self.text_encodings
-        return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
+        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
 
     @torch.no_grad()
     def embed_image(self, image):
@@ -272,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image = resize_image_to(image, self.image_size)
         image = self.clip_normalize(unnormalize_img(image))
         image_embed = self.clip.encode_image(image)
-        return EmbeddedImage(image_embed.float(), None)
+        return EmbeddedImage(l2norm(image_embed.float()), None)
 
 # classifier free guidance functions
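Note: the `l2norm` helper this commit wraps around both embeddings is not shown in the diff. Judging from how it is used here, it is presumably just unit normalization along the feature dimension; a minimal sketch under that assumption:

import torch
import torch.nn.functional as F

# assumed definition of l2norm: unit-normalize along the last dim,
# so downstream cosine similarities reduce to dot products
def l2norm(t):
    return F.normalize(t, dim = -1)

embed = torch.randn(4, 512)
assert torch.allclose(l2norm(embed).norm(dim = -1), torch.ones(4), atol = 1e-5)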
@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
         # but let's just do it right
 
         if exists(mask):
-            mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+            mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.time_embeddings(diffusion_timesteps)
         time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
             text_encodings,
             text_embed,
             time_embed,
+            image_embed,
             learned_queries
         ), dim = -2)
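Note: `(0, 3)` in `F.pad` means (pad_left, pad_right) on the mask's last dimension, so the fix appends three always-attend positions for the tokens concatenated after the text encodings. A toy sketch of that padding behavior (example mask, not from the repo):

import torch
import torch.nn.functional as F

# toy boolean text mask, shape (batch = 1, text_len = 3); False marks padding
mask = torch.tensor([[True, True, False]])

# append 3 True slots on the right of the last dim
mask = F.pad(mask, (0, 3), value = True)
print(mask)  # tensor([[ True,  True, False,  True,  True,  True]])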

View File

@@ -3,14 +3,15 @@ import copy
 from random import choice
 from pathlib import Path
 from shutil import rmtree
-from PIL import Image
 
 import torch
 from torch import nn
+from torch.cuda.amp import autocast, GradScaler
 
+from PIL import Image
+from torchvision.datasets import ImageFolder
+import torchvision.transforms as T
 from torch.utils.data import Dataset, DataLoader, random_split
-import torchvision.transforms as T
-from torchvision.datasets import ImageFolder
 from torchvision.utils import make_grid, save_image
 
 from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
         ema_update_after_step = 2000,
         ema_update_every = 10,
         apply_grad_penalty_every = 4,
+        amp = False
     ):
         super().__init__()
         assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
         self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
         self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
 
+        self.amp = amp
+        self.scaler = GradScaler(enabled = amp)
+        self.discr_scaler = GradScaler(enabled = amp)
+
         # create dataset
 
         self.ds = ImageDataset(folder, image_size = image_size)
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
             img = next(self.dl)
             img = img.to(device)
 
-            loss = self.vae(
-                img,
-                return_loss = True,
-                apply_grad_penalty = apply_grad_penalty
-            )
+            with autocast(enabled = self.amp):
+                loss = self.vae(
+                    img,
+                    return_loss = True,
+                    apply_grad_penalty = apply_grad_penalty
+                )
+
+            self.scaler.scale(loss / self.grad_accum_every).backward()
 
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
 
-            (loss / self.grad_accum_every).backward()
-
-        self.optim.step()
+        self.scaler.step(self.optim)
+        self.scaler.update()
         self.optim.zero_grad()
 
         # update discriminator
 
         if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
                 img = next(self.dl)
                 img = img.to(device)
 
-                loss = self.vae(img, return_discr_loss = True)
+                with autocast(enabled = self.amp):
+                    loss = self.vae(img, return_discr_loss = True)
+
+                self.discr_scaler.scale(loss / self.grad_accum_every).backward()
 
                 accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
 
-                (loss / self.grad_accum_every).backward()
-
-            self.discr_optim.step()
+            self.discr_scaler.step(self.discr_optim)
+            self.discr_scaler.update()
             self.discr_optim.zero_grad()
 
             # log
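Note: the trainer changes above follow the standard torch.cuda.amp recipe: run the forward pass under autocast, scale the loss before backward, and step/update through a GradScaler so steps with inf/nan gradients are skipped. A self-contained sketch of the same pattern (toy model and optimizer, not from this repo; requires a CUDA device):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(16, 1).cuda()
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)
scaler = GradScaler(enabled = True)

for _ in range(10):
    x = torch.randn(8, 16, device = 'cuda')
    with autocast(enabled = True):    # forward in mixed precision
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()     # scale loss to avoid fp16 gradient underflow
    scaler.step(optim)                # unscales grads, skips the step on inf/nan
    scaler.update()                   # adapts the scale factor for the next step
    optim.zero_grad()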

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.1.2',
+  version = '0.1.6',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

View File

@@ -46,28 +46,60 @@ def save_model(save_path, state_dict):
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
 
-def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
+
+def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device):
     cos = nn.CosineSimilarity(dim=1, eps=1e-6)
     tstart = train_set_size+val_set_size
     tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
-    for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
+    for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
+        # make a copy of the text embeddings for shuffling
         text_embed = torch.tensor(embt[0]).to(device)
-        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
-        test_text_cond = dict(text_embed = text_embed)
+        text_embed_shuffled = text_embed.clone()
+
+        # roll the text embeddings to simulate "unrelated" captions
+        rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
+        text_embed_shuffled = text_embed_shuffled[rolled_idx]
+        text_embed_shuffled = text_embed_shuffled / \
+            text_embed_shuffled.norm(dim=1, keepdim=True)
+        test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
+
+        # prepare the text embedding
+        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
+        test_text_cond = dict(text_embed=text_embed)
+
+        # prepare image embeddings
         test_image_embeddings = torch.tensor(embi[0]).to(device)
-        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
+        test_image_embeddings = test_image_embeddings / \
+            test_image_embeddings.norm(dim=1, keepdim=True)
 
-        predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
-        predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
+        # predict on the unshuffled text embeddings
+        predicted_image_embeddings = diffusion_prior.p_sample_loop(
+            (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
+        predicted_image_embeddings = predicted_image_embeddings / \
+            predicted_image_embeddings.norm(dim=1, keepdim=True)
 
-        original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
-        predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
+        # predict on the shuffled embeddings
+        predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
+            (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
+        predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
+            predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
 
-        wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
-        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
+        # calculate similarities
+        original_similarity = cos(
+            text_embed, test_image_embeddings).cpu().numpy()
+        predicted_similarity = cos(
+            text_embed, predicted_image_embeddings).cpu().numpy()
+        unrelated_similarity = cos(
+            text_embed, predicted_unrelated_embeddings).cpu().numpy()
+
+        wandb.log(
+            {"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean(
+            predicted_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
+            unrelated_similarity)})
 
     return np.mean(predicted_similarity - original_similarity)
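Note: the torch.roll trick above builds the "unrelated captions" baseline by pairing every text embedding with a different sample from the same batch, giving each logged similarity a floor to compare against. A toy sketch of just that pairing (random embeddings standing in for the prior's inputs and outputs):

import torch
from torch import nn
import torch.nn.functional as F

n, d = 8, 768
text = F.normalize(torch.randn(n, d), dim = 1)
image = F.normalize(torch.randn(n, d), dim = 1)

# shift the batch by one so row i is paired with row i-1's embedding
rolled_idx = torch.roll(torch.arange(n), 1)  # [7, 0, 1, ..., 6]
text_shuffled = text[rolled_idx]

cos = nn.CosineSimilarity(dim = 1, eps = 1e-6)
print(cos(text, image).mean())           # similarity of the matched pairs
print(cos(text_shuffled, image).mean())  # the "unrelated" baseline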