Compare commits

...

3 Commits

Author SHA1 Message Date
Phil Wang
9359ad2e91 0.0.95 2022-05-04 10:53:05 -07:00
Phil Wang
9ff228188b offer old resnet blocks, from the original DDPM paper, just in case convnexts are unsuitable for generative work 2022-05-04 10:52:58 -07:00
Kumar R
2d9963d30e Reporting metrics - Cosine similarity. (#55)
* Update train_diffusion_prior.py

* Delete train_diffusion_prior.py

* Cosine similarity logging.

* Update train_diffusion_prior.py

* Report Cosine metrics every N steps.
2022-05-04 08:04:36 -07:00
4 changed files with 143 additions and 22 deletions

View File

@@ -911,4 +911,4 @@ Once built, images will be saved to the same directory the command is invoked
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -930,6 +930,72 @@ class SinusoidalPosEmb(nn.Module):
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
class Block(nn.Module):
def __init__(
self,
dim,
dim_out,
groups = 8
):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
nn.GroupNorm(groups, dim_out),
nn.SiLU()
)
def forward(self, x):
return self.block(x)
class ResnetBlock(nn.Module):
def __init__(
self,
dim,
dim_out,
*,
cond_dim = None,
time_cond_dim = None,
groups = 8
):
super().__init__()
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_cond_dim, dim_out)
)
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim_out,
context_dim = cond_dim
)
)
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None):
h = self.block1(x)
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
h = rearrange(time_emb, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h
h = self.block2(h)
return h + self.res_conv(x)
class ConvNextBlock(nn.Module):
""" https://arxiv.org/abs/2201.03545 """
@@ -940,8 +1006,7 @@ class ConvNextBlock(nn.Module):
*,
cond_dim = None,
time_cond_dim = None,
mult = 2,
norm = True
mult = 2
):
super().__init__()
need_projection = dim != dim_out
@@ -970,7 +1035,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim) if norm else nn.Identity(),
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -1082,7 +1147,11 @@ class LinearAttention(nn.Module):
self.nonlin = nn.GELU()
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1, bias = False),
ChanLayerNorm(dim)
)
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
@@ -1125,7 +1194,9 @@ class Unet(nn.Module):
max_text_len = 256,
cond_on_image_embeds = False,
init_dim = None,
init_conv_kernel_size = 7
init_conv_kernel_size = 7,
block_type = 'resnet',
**kwargs
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -1200,6 +1271,15 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# whether to use resnet or the (improved?) convnext blocks
if block_type == 'resnet':
block_klass = ResnetBlock
elif block_type == 'convnext':
block_klass = ConvNextBlock
else:
raise ValueError(f'unimplemented block type {block_type}')
# layers
self.downs = nn.ModuleList([])
@@ -1212,32 +1292,32 @@ class Unet(nn.Module):
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
ConvNextBlock(dim, dim),
block_klass(dim, dim),
nn.Conv2d(dim, out_dim, 1)
)
@@ -1368,10 +1448,10 @@ class Unet(nn.Module):
hiddens = []
for convnext, sparse_attn, convnext2, downsample in self.downs:
x = convnext(x, c, t)
for block1, sparse_attn, block2, downsample in self.downs:
x = block1(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c, t)
x = block2(x, c, t)
hiddens.append(x)
x = downsample(x)
@@ -1382,11 +1462,11 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for convnext, sparse_attn, convnext2, upsample in self.ups:
for block1, sparse_attn, block2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c, t)
x = block1(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c, t)
x = block2(x, c, t)
x = upsample(x)
return self.final_conv(x)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.94',
version = '0.0.95',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',

View File

@@ -1,21 +1,23 @@
import os
import math
import argparse
import numpy as np
import torch
from torch import nn
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
import time
from tqdm import tqdm
import wandb
os.environ["WANDB_SILENT"] = "true"
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
model.eval()
@@ -44,6 +46,33 @@ def save_model(save_path, state_dict):
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size+val_set_size
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
text_embed = torch.tensor(embt[0]).to(device)
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed = text_embed)
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
return np.mean(predicted_similarity - original_similarity)
def train(image_embed_dim,
image_embed_url,
text_embed_url,
@@ -137,6 +166,18 @@ def train(image_embed_dim,
wandb.log({"Training loss": loss.item(),
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
diff_cosine_sim = report_cosine_sims(diffusion_prior,
image_reader,
text_reader,
train_set_size,
val_set_size,
NUM_TEST_EMBEDDINGS,
device)
wandb.log({"Cosine similarity difference": diff_cosine_sim})
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)