Compare commits

...

7 Commits

Author SHA1 Message Date
Phil Wang
3df899f7a4 patch 2022-05-31 09:03:43 -07:00
Aidan Dempster
09534119a1 Fixed non deterministic optimizer creation (#130) 2022-05-31 09:03:20 -07:00
Phil Wang
6f8b90d4d7 add packaging package 2022-05-30 11:45:00 -07:00
Phil Wang
b588286288 fix version 2022-05-30 11:06:34 -07:00
Phil Wang
b693e0be03 default number of resnet blocks per layer in unet to 2 (in imagen it was 3 for base 64x64) 2022-05-30 10:06:48 -07:00
Phil Wang
a0bed30a84 additional conditioning on image embedding by summing to time embeddings (for FiLM like conditioning in subsequent layers), from passage found in paper by @mhh0318 2022-05-30 09:26:51 -07:00
zion
387c5bf774 quick patch for new prior loader (#123) 2022-05-29 16:25:53 -07:00
7 changed files with 75 additions and 35 deletions

View File

@@ -1,3 +1,4 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

View File

@@ -1343,10 +1343,11 @@ class Unet(nn.Module):
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
init_dim = None,
init_conv_kernel_size = 7,
resnet_groups = 8,
num_resnet_blocks = 1,
num_resnet_blocks = 2,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1396,11 +1397,16 @@ class Unet(nn.Module):
nn.Linear(time_cond_dim, time_cond_dim)
)
self.image_to_cond = nn.Sequential(
self.image_to_tokens = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
self.to_image_hiddens = nn.Sequential(
nn.Linear(image_embed_dim, time_cond_dim),
nn.GELU()
) if cond_on_image_embeds and add_image_embeds_to_time else None
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
@@ -1558,6 +1564,13 @@ class Unet(nn.Module):
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
t = t + image_hiddens
# conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
@@ -1571,7 +1584,7 @@ class Unet(nn.Module):
image_tokens = None
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where(

View File

@@ -1,8 +1,10 @@
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
wd_params, no_wd_params = [], []
for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params
def get_optimizer(
@@ -25,8 +27,8 @@ def get_optimizer(
wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
{'params': wd_params},
{'params': no_wd_params, 'weight_decay': 0},
]
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
import numpy as np
@@ -57,8 +59,7 @@ def num_to_groups(num, divisor):
return arr
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
return __version__
# decorators
@@ -299,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version(),
version = __version__,
step = self.step.item(),
**kwargs
)
@@ -315,8 +316,8 @@ class DiffusionPriorTrainer(nn.Module):
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
@@ -463,7 +464,7 @@ class DecoderTrainer(nn.Module):
save_obj = dict(
model = self.decoder.state_dict(),
version = get_pkg_version(),
version = __version__,
step = self.step.item(),
**kwargs
)
@@ -486,7 +487,7 @@ class DecoderTrainer(nn.Module):
loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']:
if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)

View File

@@ -0,0 +1 @@
__version__ = '0.6.4'

View File

@@ -1,4 +1,5 @@
from setuptools import setup, find_packages
exec(open('dalle2_pytorch/version.py').read())
setup(
name = 'dalle2-pytorch',
@@ -10,7 +11,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.5.7',
version = __version__,
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -31,6 +32,7 @@ setup(
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'packaging',
'pillow',
'pydantic',
'resize-right>=0.0.2',

View File

@@ -7,15 +7,13 @@ import torch
import clip
from torch import nn
from dalle2_pytorch.dataloaders import make_splits
from dalle2_pytorch.dataloaders import make_splits, get_reader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon
from embedding_reader import EmbeddingReader
from tqdm import tqdm
# constants
@@ -31,7 +29,7 @@ def exists(val):
# functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
model.eval()
with torch.no_grad():
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
total_samples = 0.
for image_embeddings, text_data in tqdm(dataloader):
image_embeddings = image_embeddings.to(device)
text_data = text_data.to(device)
batches = image_embeddings.shape[0]
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
tracker.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for test_image_embeddings, text_data in tqdm(dataloader):
test_image_embeddings = test_image_embeddings.to(device)
text_data = text_data.to(device)
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
@@ -240,7 +242,7 @@ def train(
# Training loop
# diffusion prior network
prior_network = DiffusionPriorNetwork(
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
@@ -249,16 +251,16 @@ def train(
ff_dropout = dropout,
normformer = dp_normformer
)
# Load clip model if text-conditioning
if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip)
else:
clip_adapter = None
# diffusion prior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip_adapter,
image_embed_dim = image_embed_dim,
@@ -296,28 +298,46 @@ def train(
# Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings")
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
if dp_condition_on_text_encodings:
loader_args = dict(**loader_args, meta_url=meta_url)
reader_args = dict(**reader_args, meta_url=meta_url)
img_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader
)
else:
loader_args = dict(**loader_args, txt_url=text_embed_url)
train_loader, eval_loader, test_loader = make_splits(**loader_args)
reader_args = dict(**reader_args, txt_url=text_embed_url)
img_reader, txt_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader,
text_reader=txt_reader
)
### Training code ###
step = 1
step = 1
timer = Timer()
epochs = num_epochs
for _ in range(epochs):
for image, text in tqdm(train_loader):
diffusion_prior.train()
image = image.to(device)
text = text.to(device)
input_args = dict(image_embed=image)
if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text)
@@ -350,9 +370,9 @@ def train(
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
step += 1
trainer.update()