mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 03:54:35 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 9359ad2e91 | |
| | 9ff228188b | |
| | 2d9963d30e | |
@@ -911,4 +911,4 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
-*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
+*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
@@ -930,6 +930,72 @@ class SinusoidalPosEmb(nn.Module):
         emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
         return torch.cat((emb.sin(), emb.cos()), dim = -1)
 
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        groups = 8
+    ):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(dim, dim_out, 3, padding = 1),
+            nn.GroupNorm(groups, dim_out),
+            nn.SiLU()
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+class ResnetBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        *,
+        cond_dim = None,
+        time_cond_dim = None,
+        groups = 8
+    ):
+        super().__init__()
+
+        self.time_mlp = None
+
+        if exists(time_cond_dim):
+            self.time_mlp = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(time_cond_dim, dim_out)
+            )
+
+        self.cross_attn = None
+
+        if exists(cond_dim):
+            self.cross_attn = EinopsToAndFrom(
+                'b c h w',
+                'b (h w) c',
+                CrossAttention(
+                    dim = dim_out,
+                    context_dim = cond_dim
+                )
+            )
+
+        self.block1 = Block(dim, dim_out, groups = groups)
+        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, cond = None, time_emb = None):
+        h = self.block1(x)
+
+        if exists(self.time_mlp) and exists(time_emb):
+            time_emb = self.time_mlp(time_emb)
+            h = rearrange(time_emb, 'b c -> b c 1 1') + h
+
+        if exists(self.cross_attn):
+            assert exists(cond)
+            h = self.cross_attn(h, context = cond) + h
+
+        h = self.block2(h)
+        return h + self.res_conv(x)
+
 class ConvNextBlock(nn.Module):
     """ https://arxiv.org/abs/2201.03545 """
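The new `ResnetBlock` conditions on time by projecting the time embedding to `dim_out` channels and adding it to the feature map between `block1` and `block2`. A minimal sketch of that broadcast with toy shapes (assumes only `torch` and `einops` are installed):

```python
import torch
from einops import rearrange

h = torch.randn(2, 64, 32, 32)  # feature map after block1: (batch, dim_out, height, width)
time_emb = torch.randn(2, 64)   # time_mlp output: (batch, dim_out)

# reshape to (b, c, 1, 1) so the embedding broadcasts over every spatial position
h = rearrange(time_emb, 'b c -> b c 1 1') + h
print(h.shape)  # torch.Size([2, 64, 32, 32])
```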
@@ -940,8 +1006,7 @@ class ConvNextBlock(nn.Module):
         *,
         cond_dim = None,
         time_cond_dim = None,
-        mult = 2,
-        norm = True
+        mult = 2
     ):
         super().__init__()
         need_projection = dim != dim_out
@@ -970,7 +1035,7 @@ class ConvNextBlock(nn.Module):
 
         inner_dim = int(dim_out * mult)
         self.net = nn.Sequential(
-            ChanLayerNorm(dim) if norm else nn.Identity(),
+            ChanLayerNorm(dim),
             nn.Conv2d(dim, inner_dim, 3, padding = 1),
             nn.GELU(),
             nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
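`ChanLayerNorm` is referenced here but not defined in this diff. As a rough stand-in, it can be read as layer normalization taken over the channel dimension of a `(b, c, h, w)` tensor; a minimal sketch under that assumption:

```python
import torch
from torch import nn

class ChanLayerNorm(nn.Module):
    # Assumed behavior: normalize over the channel dim of (b, c, h, w),
    # with a learned per-channel gain.
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g
```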
@@ -1082,7 +1147,11 @@ class LinearAttention(nn.Module):
 
         self.nonlin = nn.GELU()
         self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
-        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Conv2d(inner_dim, dim, 1, bias = False),
+            ChanLayerNorm(dim)
+        )
 
     def forward(self, fmap):
         h, x, y = self.heads, *fmap.shape[-2:]
@@ -1125,7 +1194,9 @@ class Unet(nn.Module):
         max_text_len = 256,
         cond_on_image_embeds = False,
         init_dim = None,
-        init_conv_kernel_size = 7
+        init_conv_kernel_size = 7,
+        block_type = 'resnet',
+        **kwargs
     ):
         super().__init__()
         # save locals to take care of some hyperparameters for cascading DDPM
@@ -1200,6 +1271,15 @@ class Unet(nn.Module):
 
         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
 
+        # whether to use resnet or the (improved?) convnext blocks
+
+        if block_type == 'resnet':
+            block_klass = ResnetBlock
+        elif block_type == 'convnext':
+            block_klass = ConvNextBlock
+        else:
+            raise ValueError(f'unimplemented block type {block_type}')
+
         # layers
 
         self.downs = nn.ModuleList([])
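With the dispatch above, the block implementation becomes a constructor switch. A hedged usage sketch (other `Unet` arguments elided; defaults are assumed to suffice for illustration):

```python
from dalle2_pytorch import Unet

unet = Unet(dim = 128, block_type = 'resnet')     # new default: ResnetBlock throughout
# unet = Unet(dim = 128, block_type = 'convnext') # restores the prior ConvNextBlock behavior
```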
@@ -1212,32 +1292,32 @@ class Unet(nn.Module):
             layer_cond_dim = cond_dim if not is_first else None
 
             self.downs.append(nn.ModuleList([
-                ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
+                block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+                block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
 
         mid_dim = dims[-1]
 
-        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
+        self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
-        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
+        self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
 
             self.ups.append(nn.ModuleList([
-                ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+                block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+                block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Upsample(dim_in)
             ]))
 
         out_dim = default(out_dim, channels)
         self.final_conv = nn.Sequential(
-            ConvNextBlock(dim, dim),
+            block_klass(dim, dim),
             nn.Conv2d(dim, out_dim, 1)
         )
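The `(dim_in, dim_out)` pairs iterated above come from zipping consecutive entries of `dims`, which is built outside this hunk; the sketch below assumes the usual `dim * mult` construction with illustrative values:

```python
dim = 128
dim_mults = (1, 2, 4, 8)  # illustrative; not taken from this diff

dims = [dim, *(dim * m for m in dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(in_out)  # [(128, 128), (128, 256), (256, 512), (512, 1024)]
```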
@@ -1368,10 +1448,10 @@ class Unet(nn.Module):
 
         hiddens = []
 
-        for convnext, sparse_attn, convnext2, downsample in self.downs:
-            x = convnext(x, c, t)
+        for block1, sparse_attn, block2, downsample in self.downs:
+            x = block1(x, c, t)
             x = sparse_attn(x)
-            x = convnext2(x, c, t)
+            x = block2(x, c, t)
             hiddens.append(x)
             x = downsample(x)
@@ -1382,11 +1462,11 @@ class Unet(nn.Module):
 
         x = self.mid_block2(x, mid_c, t)
 
-        for convnext, sparse_attn, convnext2, upsample in self.ups:
+        for block1, sparse_attn, block2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = convnext(x, c, t)
+            x = block1(x, c, t)
             x = sparse_attn(x)
-            x = convnext2(x, c, t)
+            x = block2(x, c, t)
             x = upsample(x)
 
         return self.final_conv(x)
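The `torch.cat((x, hiddens.pop()), dim=1)` line explains why each upsampling stage's first block is constructed with `dim_out * 2` input channels: the saved encoder feature map is concatenated along the channel dimension before the block runs. Toy-shape sketch:

```python
import torch

x = torch.randn(1, 256, 16, 16)     # decoder feature map (dim_out = 256)
skip = torch.randn(1, 256, 16, 16)  # matching encoder output popped from hiddens

x = torch.cat((x, skip), dim = 1)   # channels double: 256 -> 512
print(x.shape)  # torch.Size([1, 512, 16, 16]), consumed by block_klass(dim_out * 2, dim_in)
```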
setup.py (2 lines changed)
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.94',
+  version = '0.0.95',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',
@@ -1,21 +1,23 @@
 import os
 import math
 import argparse
 import numpy as np
 
 import torch
 from torch import nn
 from embedding_reader import EmbeddingReader
 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
 from dalle2_pytorch.optimizer import get_optimizer
 from torch.cuda.amp import autocast,GradScaler
 
 import time
 from tqdm import tqdm
 
 import wandb
 os.environ["WANDB_SILENT"] = "true"
+NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
+REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
 
 
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
@@ -44,6 +46,33 @@ def save_model(save_path, state_dict):
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
 
+def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
+    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+
+    tstart = train_set_size+val_set_size
+    tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
+
+    for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
+        text_embed = torch.tensor(embt[0]).to(device)
+        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
+        test_text_cond = dict(text_embed = text_embed)
+
+        test_image_embeddings = torch.tensor(embi[0]).to(device)
+        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
+
+        predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
+        predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
+
+        original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
+        predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
+
+        wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
+
+    return np.mean(predicted_similarity - original_similarity)
+
+
 def train(image_embed_dim,
           image_embed_url,
           text_embed_url,
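The reported metric is the gap between `cos(text, predicted_image)` and `cos(text, image)` on held-out embeddings; a positive gap means the prior's samples land closer to the text embedding than the real image embeddings do. A self-contained sketch of the similarity computation with random stand-in embeddings (768 matches the width hard-coded in the `p_sample_loop` call above):

```python
import torch
from torch import nn

cos = nn.CosineSimilarity(dim = 1, eps = 1e-6)

text_embed = torch.randn(100, 768)
image_embed = torch.randn(100, 768)
text_embed = text_embed / text_embed.norm(dim = 1, keepdim = True)
image_embed = image_embed / image_embed.norm(dim = 1, keepdim = True)

print(cos(text_embed, image_embed).mean())  # near 0 for random, unrelated vectors
```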
@@ -137,6 +166,18 @@ def train(image_embed_dim,
                 wandb.log({"Training loss": loss.item(),
                     "Steps": step,
                     "Samples per second": samples_per_sec})
+                # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
+                # Use NUM_TEST_EMBEDDINGS samples from the test set each time
+                # Get embeddings from the most recently saved model
+                if(step % REPORT_METRICS_EVERY) == 0:
+                    diff_cosine_sim = report_cosine_sims(diffusion_prior,
+                        image_reader,
+                        text_reader,
+                        train_set_size,
+                        val_set_size,
+                        NUM_TEST_EMBEDDINGS,
+                        device)
+                    wandb.log({"Cosine similarity difference": diff_cosine_sim})
 
             scaler.unscale_(optimizer)
             nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
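The `scaler.unscale_(optimizer)` call before `nn.utils.clip_grad_norm_` follows the standard mixed-precision recipe: gradients must be unscaled first so the clipping threshold applies to true gradient magnitudes rather than scaled ones. A self-contained sketch of that ordering (toy model; `enabled = False` so it also runs without a GPU):

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler(enabled = False)

loss = model(torch.randn(4, 8)).sum()
scaler.scale(loss).backward()

scaler.unscale_(optimizer)                         # restore true gradient scale
nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # max-norm now applies to real grads
scaler.step(optimizer)
scaler.update()
```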