Compare commits

...

4 Commits
0.1.2 ... 0.1.6

Author SHA1 Message Date
Phil Wang
85ed77d512 fix a potentially huge bug thanks to @CiaoHe https://github.com/lucidrains/DALLE2-pytorch/issues/71 2022-05-07 05:05:54 -07:00
Piero Rolando
fd53fa17db Fix a typo in README (#70)
Change "pyhon" for "python" (correct)
2022-05-06 16:53:36 -07:00
Phil Wang
3676ef4d49 make sure vqgan-vae trainer supports mixed precision 2022-05-06 10:44:16 -07:00
Phil Wang
28e944f328 make sure openai clip adapter outputs l2normed embeddings 2022-05-06 10:12:03 -07:00
4 changed files with 34 additions and 22 deletions

View File

@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
 ### Usage
 ```bash
-$ pyhon train_diffusion_prior.py
+$ python train_diffusion_prior.py
 ```
 The most significant parameters for the script are as follows:

View File

@@ -264,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         text_embed = self.clip.encode_text(text)
         text_encodings = self.text_encodings
         del self.text_encodings
-        return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
+        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
     @torch.no_grad()
     def embed_image(self, image):
@@ -272,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image = resize_image_to(image, self.image_size)
         image = self.clip_normalize(unnormalize_img(image))
         image_embed = self.clip.encode_image(image)
-        return EmbeddedImage(image_embed.float(), None)
+        return EmbeddedImage(l2norm(image_embed.float()), None)
 # classifier free guidance functions
@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
         # but let's just do it right
         if exists(mask):
-            mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+            mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
         time_embed = self.time_embeddings(diffusion_timesteps)
         time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
             text_encodings,
             text_embed,
             time_embed,
+            image_embed,
             learned_queries
         ), dim = -2)
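For context, the adapter changes above wrap the raw CLIP outputs in an `l2norm` call so the diffusion prior always receives unit-length embeddings. A minimal sketch of that behaviour, assuming `l2norm` is the usual `F.normalize` wrapper used in this codebase:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # normalize along the last (feature) dimension so every embedding has unit L2 norm
    return F.normalize(t, dim = -1)

# toy check: stand-in for clip.encode_image(...) output
image_embed = torch.randn(4, 512)
print(image_embed.norm(dim = -1))          # arbitrary magnitudes
print(l2norm(image_embed).norm(dim = -1))  # all ones after normalization
```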

View File

@@ -3,14 +3,15 @@ import copy
 from random import choice
 from pathlib import Path
 from shutil import rmtree
-from PIL import Image
 import torch
 from torch import nn
+from torch.cuda.amp import autocast, GradScaler
+from PIL import Image
+from torchvision.datasets import ImageFolder
+import torchvision.transforms as T
 from torch.utils.data import Dataset, DataLoader, random_split
-import torchvision.transforms as T
-from torchvision.datasets import ImageFolder
 from torchvision.utils import make_grid, save_image
 from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
         ema_update_after_step = 2000,
         ema_update_every = 10,
         apply_grad_penalty_every = 4,
+        amp = False
     ):
         super().__init__()
         assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
         self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
         self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
+        self.amp = amp
+        self.scaler = GradScaler(enabled = amp)
+        self.discr_scaler = GradScaler(enabled = amp)
         # create dataset
         self.ds = ImageDataset(folder, image_size = image_size)
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
             img = next(self.dl)
             img = img.to(device)
-            loss = self.vae(
-                img,
-                return_loss = True,
-                apply_grad_penalty = apply_grad_penalty
-            )
+            with autocast(enabled = self.amp):
+                loss = self.vae(
+                    img,
+                    return_loss = True,
+                    apply_grad_penalty = apply_grad_penalty
+                )
+            self.scaler.scale(loss / self.grad_accum_every).backward()
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
-            (loss / self.grad_accum_every).backward()
-        self.optim.step()
+        self.scaler.step(self.optim)
+        self.scaler.update()
         self.optim.zero_grad()
         # update discriminator
         if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
                 img = next(self.dl)
                 img = img.to(device)
-                loss = self.vae(img, return_discr_loss = True)
+                with autocast(enabled = self.amp):
+                    loss = self.vae(img, return_discr_loss = True)
+                self.discr_scaler.scale(loss / self.grad_accum_every).backward()
                 accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
-                (loss / self.grad_accum_every).backward()
-            self.discr_optim.step()
+            self.discr_scaler.step(self.discr_optim)
+            self.discr_scaler.update()
             self.discr_optim.zero_grad()
             # log
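The trainer changes above follow the standard `torch.cuda.amp` recipe: run the forward pass under `autocast`, scale the loss before `backward()`, then let the `GradScaler` step the optimizer and update its loss scale. A stripped-down sketch of that pattern with gradient accumulation; `model`, `optim`, and `batches` here are placeholders for illustration, not the trainer's real attributes:

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(512, 512).cuda()
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)
scaler = GradScaler(enabled = True)        # persists across steps and tracks the loss scale
grad_accum_every = 4

def train_step(batches):
    for img in batches:                    # gradient accumulation loop
        with autocast(enabled = True):     # forward pass in mixed precision
            loss = model(img).pow(2).mean()

        # scale before backward so small fp16 gradients do not underflow
        scaler.scale(loss / grad_accum_every).backward()

    scaler.step(optim)                     # unscales grads, skips the step on inf/nan
    scaler.update()                        # adjusts the loss scale for the next step
    optim.zero_grad()

train_step([torch.randn(8, 512).cuda() for _ in range(grad_accum_every)])
```

With `amp = False`, both `autocast` and `GradScaler` become no-ops, so the same code path serves full-precision training, which is why the diff can gate everything on a single constructor flag.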

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.1.2',
+  version = '0.1.6',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',