Compare commits


3 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Phil Wang | 138079ca83 | allow for setting beta schedules of unets differently in the decoder, as what was used in the paper was cosine, cosine, linear | 2022-06-20 08:56:37 -07:00 |
| zion | f5a906f5d3 | prior train script bug fixes (#153) | 2022-06-19 15:55:15 -07:00 |
| Phil Wang | 0215237fc6 | update status | 2022-06-19 09:42:24 -07:00 |
6 changed files with 125 additions and 86 deletions
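The headline change (138079ca83) lets each unet in the decoder cascade run on its own noise schedule, matching the paper's cosine, cosine, linear setup. A minimal usage sketch, assuming three already-constructed `Unet` instances and standard `Decoder` keyword arguments (variable names here are hypothetical):

```python
# hypothetical sketch: per-unet beta schedules in a three-stage decoder cascade
decoder = Decoder(
    unet = (unet1, unet2, unet3),                    # base unet + two super-res unets, assumed built elsewhere
    image_sizes = (64, 256, 1024),                   # output resolution at each stage
    timesteps = 1000,
    beta_schedule = ('cosine', 'cosine', 'linear'),  # one schedule per unet, as used in the paper
)
```

Omitting `beta_schedule` now builds this same tuple automatically (see the default constructed in the diff below).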

View File

@@ -27,8 +27,10 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
- DALL-E 2 🚧
## Appreciation

View File

@@ -378,7 +378,7 @@ def sigmoid_beta_schedule(timesteps):
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
class BaseGaussianDiffusion(nn.Module):
class NoiseScheduler(nn.Module):
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
super().__init__()
@@ -472,11 +472,10 @@ class BaseGaussianDiffusion(nn.Module):
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def sample(self, *args, **kwargs):
raise NotImplementedError
def forward(self, *args, **kwargs):
raise NotImplementedError
def p2_reweigh_loss(self, loss, times):
if not self.has_p2_loss_reweighting:
return loss
return loss * extract(self.p2_loss_weight, times, loss.shape)
# diffusion prior
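For context on `p2_reweigh_loss`: P2 weighting (Choi et al., arXiv 2204.00227) reweights the loss per timestep by signal-to-noise ratio. A sketch of the weight buffer as the `NoiseScheduler` constructor registers it, paraphrased from the repository (`alphas_cumprod` is the usual cumulative product of 1 − β):

```python
# weight_t = (k + SNR_t) ** -gamma, with SNR_t = alpha_bar_t / (1 - alpha_bar_t);
# gamma = 0 makes every weight 1 (no reweighting), gamma = 1 is the paper's recommendation
p2_loss_weight = (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma
```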
@@ -687,8 +686,7 @@ class Attention(nn.Module):
# attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)
# aggregate values
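Both attention softmaxes gained `dtype = torch.float32`, which upcasts the normalization when the logits arrive in half precision under mixed-precision training. A small standalone sketch of the difference (CUDA assumed for the fp16 input):

```python
import torch

sim = torch.randn(1, 8, 64, 64, dtype = torch.float16, device = 'cuda')

attn_half = sim.softmax(dim = -1)                         # normalized entirely in fp16
attn_full = sim.softmax(dim = -1, dtype = torch.float32)  # upcast to fp32 before the exp/sum
```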
@@ -862,7 +860,7 @@ class DiffusionPriorNetwork(nn.Module):
return pred_image_embed
class DiffusionPrior(BaseGaussianDiffusion):
class DiffusionPrior(nn.Module):
def __init__(
self,
net,
@@ -883,7 +881,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__(
super().__init__()
self.noise_scheduler = NoiseScheduler(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type
@@ -923,6 +923,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
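With `BaseGaussianDiffusion` gone, `DiffusionPrior` no longer inherits registered buffers such as `betas` to infer its device from, hence the non-persistent `_dummy` buffer and `device` property above. The idiom in isolation:

```python
import torch
from torch import nn

class DeviceTracked(nn.Module):
    def __init__(self):
        super().__init__()
        # follows the module through .to() / .cuda(), but is left out of the state dict
        self.register_buffer('_dummy', torch.tensor([True]), persistent = False)

    @property
    def device(self):
        return self._dummy.device
```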
@@ -933,7 +940,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
@@ -941,7 +948,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
@@ -955,7 +962,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
@torch.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.betas.device
device = self.device
b = shape[0]
image_embed = torch.randn(shape, device=device)
@@ -963,7 +970,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
@@ -972,7 +979,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
pred = self.net(
image_embed_noisy,
@@ -986,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target)
loss = self.noise_scheduler.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -997,7 +1004,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
return img
@@ -1069,7 +1076,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
# scale image embed (Katherine)
@@ -1234,8 +1241,7 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1739,7 +1745,7 @@ class LowresConditioner(nn.Module):
return cond_fmap
class Decoder(BaseGaussianDiffusion):
class Decoder(nn.Module):
def __init__(
self,
unet,
@@ -1752,7 +1758,7 @@ class Decoder(BaseGaussianDiffusion):
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l2',
beta_schedule = 'cosine',
beta_schedule = None,
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
@@ -1774,13 +1780,7 @@ class Decoder(BaseGaussianDiffusion):
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
):
super().__init__(
beta_schedule = beta_schedule,
timesteps = timesteps,
loss_type = loss_type,
p2_loss_weight_gamma = p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)
super().__init__()
self.unconditional = unconditional
@@ -1824,6 +1824,8 @@ class Decoder(BaseGaussianDiffusion):
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
unets = cast_tuple(unet)
num_unets = len(unets)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1859,6 +1861,24 @@ class Decoder(BaseGaussianDiffusion):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# create noise schedulers per unet
if not exists(beta_schedule):
beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule in beta_schedule:
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
timesteps = timesteps,
loss_type = loss_type,
p2_loss_weight_gamma = p2_loss_weight_gamma,
p2_loss_weight_k = p2_loss_weight_k
)
self.noise_schedulers.append(noise_scheduler)
# unet image sizes
image_sizes = default(image_sizes, (image_size,))
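Worked out, the default `beta_schedule` tuple above yields cosine for every stage except the final super-resolution unet:

```python
# num_unets = 1  ->  ('cosine',)
# num_unets = 2  ->  ('cosine', 'linear')
# num_unets = 3  ->  ('cosine', 'cosine', 'linear')   # the paper's configuration
```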
@@ -1908,6 +1928,14 @@ class Decoder(BaseGaussianDiffusion):
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -1931,7 +1959,7 @@ class Decoder(BaseGaussianDiffusion):
for unet, device in zip(self.unets, devices):
unet.to(device)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
@@ -1942,7 +1970,7 @@ class Decoder(BaseGaussianDiffusion):
if predict_x_start:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
# s is the threshold amount
@@ -1961,14 +1989,14 @@ class Decoder(BaseGaussianDiffusion):
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
if learned_variance:
# if learned variance, posterior variance and posterior log variance are predicted by the network
# by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
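The interpolation being rewired here is eq. 15 of Nichol & Dhariwal (arXiv 2102.09672): the network predicts a fraction $v \in [0, 1]$ and the variance is

$$\Sigma_\theta(x_t, t) = \exp\big(v \log \beta_t + (1 - v) \log \tilde{\beta}_t\big)$$

where $\tilde{\beta}_t$ is the clipped posterior variance; the only change in this diff is that $\beta_t$ and $\tilde{\beta}_t$ are now read from the per-unet `noise_scheduler` rather than from `self`.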
@@ -1980,17 +2008,17 @@ class Decoder(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.betas.device
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
img = torch.randn(shape, device = device)
@@ -1998,7 +2026,7 @@ class Decoder(BaseGaussianDiffusion):
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
img = self.p_sample(
unet,
img,
@@ -2009,6 +2037,7 @@ class Decoder(BaseGaussianDiffusion):
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
)
@@ -2016,7 +2045,7 @@ class Decoder(BaseGaussianDiffusion):
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -2027,7 +2056,7 @@ class Decoder(BaseGaussianDiffusion):
# get x_t
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
model_output = unet(
x_noisy,
@@ -2047,11 +2076,10 @@ class Decoder(BaseGaussianDiffusion):
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target, reduction = 'none')
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
if self.has_p2_loss_reweighting:
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
loss = noise_scheduler.p2_reweigh_loss(loss, times)
loss = loss.mean()
@@ -2066,8 +2094,8 @@ class Decoder(BaseGaussianDiffusion):
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
@@ -2117,7 +2145,7 @@ class Decoder(BaseGaussianDiffusion):
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2145,7 +2173,8 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion
is_latent_diffusion = is_latent_diffusion,
noise_scheduler = noise_scheduler
)
img = vae.decode(img)
@@ -2171,6 +2200,7 @@ class Decoder(BaseGaussianDiffusion):
unet = self.get_unet(unet_number)
vae = self.vaes[unet_index]
noise_scheduler = self.noise_schedulers[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
@@ -2180,7 +2210,7 @@ class Decoder(BaseGaussianDiffusion):
check_shape(image, 'b c h w', c = self.channels)
assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
@@ -2211,7 +2241,7 @@ class Decoder(BaseGaussianDiffusion):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
# main class

View File

@@ -173,7 +173,7 @@ class DecoderConfig(BaseModel):
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True
image_cond_drop_prob: float = 0.1
text_cond_drop_prob: float = 0.5
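Since `beta_schedule` is now `ListOrTuple(str)`, a decoder config can spell out one schedule per unet. A hedged sketch (other required fields, such as the unet configs, are elided):

```python
decoder_config = DecoderConfig(
    # ... unet configs and other required fields elided ...
    channels = 3,
    timesteps = 1000,
    loss_type = 'l2',
    beta_schedule = ['cosine', 'cosine', 'linear'],  # one entry per unet
)
```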

View File

@@ -24,7 +24,9 @@ def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
if exists(val):
return val
return d() if callable(d) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
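The updated `default` lazily invokes callable fallbacks, so an expensive default is only computed when the value is actually missing; this is what lets call sites elsewhere in the diff write:

```python
noise = default(noise, lambda: torch.randn_like(x_start))  # randn_like runs only when noise is None
```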
@@ -574,8 +576,8 @@ def decoder_sample_in_chunks(fn):
class DecoderTrainer(nn.Module):
def __init__(
self,
accelerator,
decoder,
accelerator = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
@@ -589,7 +591,7 @@ class DecoderTrainer(nn.Module):
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.accelerator = accelerator
self.accelerator = default(accelerator, Accelerator)
self.num_unets = len(decoder.unets)
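With `accelerator` now optional and `Accelerator` (the class itself) passed as the callable fallback, `DecoderTrainer` constructs a default `accelerate.Accelerator()` when none is given. A hedged usage sketch:

```python
from accelerate import Accelerator

trainer = DecoderTrainer(decoder)  # builds its own Accelerator() internally
# or pass a configured one explicitly:
trainer = DecoderTrainer(decoder, accelerator = Accelerator(split_batches = True))
```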

View File

@@ -1 +1 @@
__version__ = '0.10.1'
__version__ = '0.11.0'

View File

@@ -12,7 +12,6 @@ import wandb
import torch
from torch import nn
from torch.nn.functional import normalize
from torch.utils.data import DataLoader
import numpy as np
@@ -68,13 +67,13 @@ def eval_model(
dataloader: DataLoader,
text_conditioned: bool,
loss_type: str,
phase: str,
tracker_context: str,
tracker: BaseTracker = None,
use_ema: bool = True,
):
trainer.eval()
if trainer.is_main_process():
click.secho(f"Measuring performance on {phase}", fg="green", blink=True)
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
with torch.no_grad():
total_loss = 0.0
@@ -103,7 +102,7 @@ def eval_model(
avg_loss = total_loss / total_samples
stats = {f"{phase}/{loss_type}": avg_loss}
stats = {f"{tracker_context}-{loss_type}": avg_loss}
trainer.print(stats)
if exists(tracker):
@@ -115,7 +114,7 @@ def report_cosine_sims(
dataloader: DataLoader,
text_conditioned: bool,
tracker: BaseTracker,
phase: str = "validation",
tracker_context: str = "validation",
):
trainer.eval()
if trainer.is_main_process():
@@ -141,7 +140,9 @@ def report_cosine_sims(
# roll the text to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / normalize(text_embed_shuffled)
text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
dim=1, keepdim=True
)
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
@@ -157,18 +158,21 @@ def report_cosine_sims(
)
# prepare the text embedding
text_embed = normalize(text_embedding / text_embedding)
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
# prepare image embeddings
test_image_embeddings = test_image_embeddings / normalize(test_image_embeddings)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
dim=1, keepdim=True
)
# predict on the unshuffled text embeddings
predicted_image_embeddings = trainer.p_sample_loop(
test_image_embeddings.shape, text_cond
)
predicted_image_embeddings = predicted_image_embeddings / normalize(
predicted_image_embeddings = (
predicted_image_embeddings
/ predicted_image_embeddings.norm(dim=1, keepdim=True)
)
# predict on the shuffled embeddings
@@ -176,8 +180,9 @@ def report_cosine_sims(
test_image_embeddings.shape, text_cond_shuffled
)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / normalize(
predicted_unrelated_embeddings = (
predicted_unrelated_embeddings
/ predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
)
# calculate similarities
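The edits above fix a unit-normalization bug in the script: `x / normalize(x)` evaluates elementwise to the row norms rather than unit vectors, and `normalize(x / x)` normalizes a tensor of ones. The replacement and the built-in are equivalent:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 512)
unit_a = x / x.norm(dim = 1, keepdim = True)  # explicit row-wise normalization, as the script now does
unit_b = F.normalize(x, dim = 1)              # equivalent built-in
assert torch.allclose(unit_a, unit_b, atol = 1e-6)
```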
@@ -191,19 +196,19 @@ def report_cosine_sims(
)
stats = {
f"{phase}/baseline similarity": np.mean(original_similarity),
f"{phase}/similarity with text": np.mean(predicted_similarity),
f"{phase}/similarity with original image": np.mean(
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
f"{tracker_context}/similarity with original image": np.mean(
predicted_img_similarity
),
f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{phase}/difference from baseline similarity": np.mean(
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
f"{tracker_context}/difference from baseline similarity": np.mean(
predicted_similarity - original_similarity
),
}
for k, v in stats.items():
trainer.print(f"{phase}/{k}: {v}")
trainer.print(f"{tracker_context}/{k}: {v}")
if exists(tracker):
tracker.log(stats, step=trainer.step.item() + 1)
@@ -269,10 +274,10 @@ def train(
# Log on all processes for debugging
tracker.log(
{
"training/loss": loss,
"samples/sec/rank": samples_per_sec,
"samples/seen": samples_seen,
"ema/decay": ema_decay,
"tracking/samples-sec": samples_per_sec,
"tracking/samples-seen": samples_seen,
"tracking/ema-decay": ema_decay,
"metrics/training-loss": loss,
},
step=current_step,
)
@@ -280,12 +285,12 @@ def train(
# Metric Tracking & Checkpointing (outside of timer's scope)
if current_step % EVAL_EVERY == 0:
eval_model(
trainer,
eval_loader,
config.prior.condition_on_text_encodings,
config.prior.loss_type,
"training/validation",
tracker,
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
tracker_context="metrics/online-model-validation",
tracker=tracker,
use_ema=False,
)
@@ -293,8 +298,8 @@ def train(
trainer=trainer,
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss=config.prior.loss_type,
phase="ema/validation",
loss_type=config.prior.loss_type,
tracker_context="metrics/ema-model-validation",
tracker=tracker,
use_ema=True,
)
@@ -304,7 +309,7 @@ def train(
dataloader=eval_loader,
text_conditioned=config.prior.condition_on_text_encodings,
tracker=tracker,
phase="ema/validation",
tracker_context="metrics",
)
if current_step % config.train.save_every == 0:
@@ -320,7 +325,7 @@ def train(
dataloader=test_loader,
text_conditioned=config.prior.condition_on_text_encodings,
loss_type=config.prior.loss_type,
phase="test",
tracker_context="test",
tracker=tracker,
)
@@ -329,7 +334,7 @@ def train(
test_loader,
config.prior.condition_on_text_encodings,
tracker,
phase="test",
tracker_context="test",
)