Compare commits


2 Commits
0.1.2 ... 0.1.5

Author     SHA1        Message                                                     Date
Phil Wang  3676ef4d49  make sure vqgan-vae trainer supports mixed precision       2022-05-06 10:44:16 -07:00
Phil Wang  28e944f328  make sure openai clip adapter outputs l2normed embeddings  2022-05-06 10:12:03 -07:00
3 changed files with 31 additions and 20 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -264,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         text_embed = self.clip.encode_text(text)
         text_encodings = self.text_encodings
         del self.text_encodings
-        return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
+        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
 
     @torch.no_grad()
     def embed_image(self, image):
@@ -272,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image = resize_image_to(image, self.image_size)
         image = self.clip_normalize(unnormalize_img(image))
         image_embed = self.clip.encode_image(image)
-        return EmbeddedImage(image_embed.float(), None)
+        return EmbeddedImage(l2norm(image_embed.float()), None)
 
 # classifier free guidance functions

dalle2_pytorch/train_vqgan_vae.py

@@ -3,14 +3,15 @@ import copy
 from random import choice
 from pathlib import Path
 from shutil import rmtree
-from PIL import Image
 
 import torch
 from torch import nn
+from torch.cuda.amp import autocast, GradScaler
 
+from PIL import Image
+from torchvision.datasets import ImageFolder
+import torchvision.transforms as T
 from torch.utils.data import Dataset, DataLoader, random_split
-import torchvision.transforms as T
-from torchvision.datasets import ImageFolder
 from torchvision.utils import make_grid, save_image
 
 from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
         ema_update_after_step = 2000,
         ema_update_every = 10,
         apply_grad_penalty_every = 4,
+        amp = False
     ):
         super().__init__()
         assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
         self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
         self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
 
+        self.amp = amp
+        self.scaler = GradScaler(enabled = amp)
+        self.discr_scaler = GradScaler(enabled = amp)
+
         # create dataset
 
         self.ds = ImageDataset(folder, image_size = image_size)
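The trainer keeps two scalers because the generator and the discriminator take separate optimizer steps. Conveniently, a GradScaler constructed with enabled = False is a transparent no-op, which is presumably why a single code path can serve both the default fp32 mode and amp = True. A quick plain-PyTorch check of that property (nothing repo-specific):

import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled = False)            # amp off -> scaler does nothing
loss = torch.tensor(2.0, requires_grad = True)
assert scaler.scale(loss) is loss               # scale() hands the loss back unchanged

Likewise, step() on a disabled scaler simply calls optimizer.step(), so the training loop below needs no amp-specific branching.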
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
             img = next(self.dl)
             img = img.to(device)
 
-            loss = self.vae(
-                img,
-                return_loss = True,
-                apply_grad_penalty = apply_grad_penalty
-            )
+            with autocast(enabled = self.amp):
+                loss = self.vae(
+                    img,
+                    return_loss = True,
+                    apply_grad_penalty = apply_grad_penalty
+                )
+
+            self.scaler.scale(loss / self.grad_accum_every).backward()
 
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
 
-            (loss / self.grad_accum_every).backward()
-
-        self.optim.step()
+        self.scaler.step(self.optim)
+        self.scaler.update()
 
         self.optim.zero_grad()
 
         # update discriminator
 
         if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
             img = next(self.dl)
             img = img.to(device)
 
-            loss = self.vae(img, return_discr_loss = True)
+            with autocast(enabled = self.amp):
+                loss = self.vae(img, return_discr_loss = True)
+
+            self.discr_scaler.scale(loss / self.grad_accum_every).backward()
 
             accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
 
-            (loss / self.grad_accum_every).backward()
-
-        self.discr_optim.step()
+        self.discr_scaler.step(self.discr_optim)
+        self.discr_scaler.update()
         self.discr_optim.zero_grad()
 
         # log
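In both branches the old backward() / optimizer.step() pair becomes: scale the loss and call backward() inside the accumulation loop, then scaler.step() and scaler.update() once after it. A minimal self-contained sketch of that pattern, using a hypothetical model and random batches rather than the trainer's own VAE and dataloader:

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

grad_accum_every = 4
amp = torch.cuda.is_available()                  # autocast only helps on CUDA
device = 'cuda' if amp else 'cpu'

model = nn.Linear(16, 1).to(device)              # hypothetical stand-in for the VAE
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)
scaler = GradScaler(enabled = amp)

for _ in range(grad_accum_every):
    x = torch.randn(8, 16, device = device)      # hypothetical batch
    with autocast(enabled = amp):                # forward pass in fp16 where safe
        loss = model(x).pow(2).mean()
    # scale before backward so small fp16 gradients do not underflow;
    # divide by the accumulation count so summed gradients average out
    scaler.scale(loss / grad_accum_every).backward()

scaler.step(optim)                               # unscales gradients, then steps
scaler.update()                                  # adapts the loss-scale factor
optim.zero_grad()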

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.1.2',
+  version = '0.1.5',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',