Compare commits

..

3 Commits
0.1.4 ... 0.1.6

Author SHA1 Message Date
Phil Wang
85ed77d512 fix a potentially huge bug thanks to @CiaoHe https://github.com/lucidrains/DALLE2-pytorch/issues/71 2022-05-07 05:05:54 -07:00
Piero Rolando
fd53fa17db Fix a typo in README (#70)
Change "pyhon" for "python" (correct)
2022-05-06 16:53:36 -07:00
Phil Wang
3676ef4d49 make sure vqgan-vae trainer supports mixed precision 2022-05-06 10:44:16 -07:00
4 changed files with 32 additions and 20 deletions

View File

@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
### Usage
```bash
$ pyhon train_diffusion_prior.py
$ python train_diffusion_prior.py
```
The most significant parameters for the script are as follows:

View File

@@ -765,7 +765,7 @@ class DiffusionPriorNetwork(nn.Module):
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -776,6 +776,7 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings,
text_embed,
time_embed,
image_embed,
learned_queries
), dim = -2)

View File

@@ -3,14 +3,15 @@ import copy
from random import choice
from pathlib import Path
from shutil import rmtree
from PIL import Image
import torch
from torch import nn
from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
amp = False
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.discr_scaler = GradScaler(enabled = amp)
# create dataset
self.ds = ImageDataset(folder, image_size = image_size)
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl)
img = img.to(device)
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
with autocast(enabled = self.amp):
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
self.scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.optim.step()
self.scaler.step(self.optim)
self.scaler.update()
self.optim.zero_grad()
# update discriminator
if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl)
img = img.to(device)
loss = self.vae(img, return_discr_loss = True)
with autocast(enabled = self.amp):
loss = self.vae(img, return_discr_loss = True)
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.discr_optim.step()
self.discr_scaler.step(self.discr_optim)
self.discr_scaler.update()
self.discr_optim.zero_grad()
# log

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.1.4',
version = '0.1.6',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',