Compare commits

...

6 Commits

Author SHA1 Message Date
Phil Wang
5d958713c0 fix classifier free guidance for image hiddens summed to time hiddens, thanks to @xvjiarui for finding this bug 2022-06-13 21:01:50 -07:00
Phil Wang
0f31980362 cleanup 2022-06-07 17:31:38 -07:00
Phil Wang
bee5bf3815 fix for https://github.com/lucidrains/DALLE2-pytorch/issues/143 2022-06-07 09:03:48 -07:00
Phil Wang
350a3d6045 0.6.16 2022-06-06 08:45:46 -07:00
Kashif Rasul
1a81670718 fix quadratic_beta_schedule (#141) 2022-06-06 08:45:14 -07:00
Phil Wang
934c9728dc some cleanup 2022-06-04 16:54:15 -07:00
4 changed files with 47 additions and 28 deletions

View File

@@ -1,7 +1,6 @@
import math import math
import random import random
from tqdm import tqdm from tqdm import tqdm
from inspect import isfunction
from functools import partial, wraps from functools import partial, wraps
from contextlib import contextmanager from contextlib import contextmanager
from collections import namedtuple from collections import namedtuple
@@ -57,7 +56,7 @@ def maybe(fn):
def default(val, d): def default(val, d):
if exists(val): if exists(val):
return val return val
return d() if isfunction(d) else d return d() if callable(d) else d
def cast_tuple(val, length = 1): def cast_tuple(val, length = 1):
if isinstance(val, list): if isinstance(val, list):
@@ -314,11 +313,6 @@ def extract(a, t, x_shape):
out = a.gather(-1, t) out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1))) return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def meanflat(x): def meanflat(x):
return x.mean(dim = tuple(range(1, len(x.shape)))) return x.mean(dim = tuple(range(1, len(x.shape))))
@@ -373,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps scale = 1000 / timesteps
beta_start = scale * 0.0001 beta_start = scale * 0.0001
beta_end = scale * 0.02 beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2 return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
def sigmoid_beta_schedule(timesteps): def sigmoid_beta_schedule(timesteps):
@@ -946,10 +940,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad() @torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.): def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale) model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = noise_like(x.shape, device, repeat_noise) noise = torch.randn_like(x)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@@ -1428,6 +1422,7 @@ class Unet(nn.Module):
# for classifier free guidance # for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
self.max_text_len = max_text_len self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
@@ -1565,19 +1560,28 @@ class Unet(nn.Module):
time_tokens = self.to_time_tokens(time_hiddens) time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens) t = self.to_time_cond(time_hiddens)
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
t = t + image_hiddens
# conditional dropout # conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device) image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device) text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1') text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
image_hiddens = torch.where(
image_keep_mask_hidden,
image_hiddens,
null_image_hiddens
)
t = t + image_hiddens
# mask out image embedding depending on condition dropout # mask out image embedding depending on condition dropout
# for classifier free guidance # for classifier free guidance
@@ -1585,11 +1589,12 @@ class Unet(nn.Module):
image_tokens = None image_tokens = None
if self.cond_on_image_embeds: if self.cond_on_image_embeds:
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
image_tokens = self.image_to_tokens(image_embed) image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where( image_tokens = torch.where(
image_keep_mask, image_keep_mask_embed,
image_tokens, image_tokens,
null_image_embed null_image_embed
) )
@@ -1956,10 +1961,10 @@ class Decoder(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad() @torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False): def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance) model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
noise = noise_like(x.shape, device, repeat_noise) noise = torch.randn_like(x)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

View File

@@ -58,8 +58,15 @@ def num_to_groups(num, divisor):
arr.append(remainder) arr.append(remainder)
return arr return arr
def get_pkg_version(): def clamp(value, min_value = None, max_value = None):
return __version__ assert exists(min_value) or exists(max_value)
if exists(min_value):
value = max(value, min_value)
if exists(max_value):
value = min(value, max_value)
return value
# decorators # decorators
@@ -227,10 +234,17 @@ class EMA(nn.Module):
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())): for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
ma_param.data.copy_(current_param.data) ma_param.data.copy_(current_param.data)
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
ma_buffer.data.copy_(current_buffer.data)
def get_current_decay(self): def get_current_decay(self):
epoch = max(0, self.step.item() - self.update_after_step - 1) epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
return 0. if epoch < 0 else min(self.beta, max(self.min_value, value))
if epoch <= 0:
return 0.
return clamp(value, min_value = self.min_value, max_value = self.beta)
def update(self): def update(self):
step = self.step.item() step = self.step.item()
@@ -521,7 +535,7 @@ class DecoderTrainer(nn.Module):
loaded_obj = torch.load(str(path)) loaded_obj = torch.load(str(path))
if version.parse(__version__) != loaded_obj['version']: if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}') print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict) self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step']) self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

View File

@@ -1 +1 @@
__version__ = '0.6.13' __version__ = '0.7.0'

View File

@@ -211,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
Loads the model with an appropriate method depending on the tracker Loads the model with an appropriate method depending on the tracker
""" """
print(print_ribbon(f"Loading model from {recall_source}")) print(print_ribbon(f"Loading model from {recall_source}"))
state_dict = tracker.recall_state_dict(recall_source, **load_config) state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
trainer.load_state_dict(state_dict["trainer"]) trainer.load_state_dict(state_dict["trainer"])
print("Model loaded") print("Model loaded")
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"] return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]