mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 13:24:25 +01:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 5d958713c0 | | |
| | 0f31980362 | | |
| | bee5bf3815 | | |
| | 350a3d6045 | | |
| | 1a81670718 | | |
| | 934c9728dc | | |
@@ -1,7 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from inspect import isfunction
|
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
@@ -57,7 +56,7 @@ def maybe(fn):
|
|||||||
def default(val, d):
|
def default(val, d):
|
||||||
if exists(val):
|
if exists(val):
|
||||||
return val
|
return val
|
||||||
return d() if isfunction(d) else d
|
return d() if callable(d) else d
|
||||||
|
|
||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
@@ -314,11 +313,6 @@ def extract(a, t, x_shape):
|
|||||||
out = a.gather(-1, t)
|
out = a.gather(-1, t)
|
||||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||||
|
|
||||||
def noise_like(shape, device, repeat=False):
|
|
||||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
|
||||||
noise = lambda: torch.randn(shape, device=device)
|
|
||||||
return repeat_noise() if repeat else noise()
|
|
||||||
|
|
||||||
def meanflat(x):
|
def meanflat(x):
|
||||||
return x.mean(dim = tuple(range(1, len(x.shape))))
|
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||||
|
|
||||||
@@ -373,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
|
|||||||
scale = 1000 / timesteps
|
scale = 1000 / timesteps
|
||||||
beta_start = scale * 0.0001
|
beta_start = scale * 0.0001
|
||||||
beta_end = scale * 0.02
|
beta_end = scale * 0.02
|
||||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
|
||||||
|
|
||||||
|
|
||||||
def sigmoid_beta_schedule(timesteps):
|
def sigmoid_beta_schedule(timesteps):
|
||||||
@@ -946,10 +940,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
|
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = torch.randn_like(x)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
@@ -1428,6 +1422,7 @@ class Unet(nn.Module):
|
|||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
|
|
||||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||||
|
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
|
||||||
|
|
||||||
self.max_text_len = max_text_len
|
self.max_text_len = max_text_len
|
||||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||||
@@ -1565,19 +1560,28 @@ class Unet(nn.Module):
|
|||||||
time_tokens = self.to_time_tokens(time_hiddens)
|
time_tokens = self.to_time_tokens(time_hiddens)
|
||||||
t = self.to_time_cond(time_hiddens)
|
t = self.to_time_cond(time_hiddens)
|
||||||
|
|
||||||
# image embedding to be summed to time embedding
|
|
||||||
# discovered by @mhh0318 in the paper
|
|
||||||
|
|
||||||
if exists(image_embed) and exists(self.to_image_hiddens):
|
|
||||||
image_hiddens = self.to_image_hiddens(image_embed)
|
|
||||||
t = t + image_hiddens
|
|
||||||
|
|
||||||
# conditional dropout
|
# conditional dropout
|
||||||
|
|
||||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||||
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
||||||
|
|
||||||
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
|
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
|
||||||
|
|
||||||
|
# image embedding to be summed to time embedding
|
||||||
|
# discovered by @mhh0318 in the paper
|
||||||
|
|
||||||
|
if exists(image_embed) and exists(self.to_image_hiddens):
|
||||||
|
image_hiddens = self.to_image_hiddens(image_embed)
|
||||||
|
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
|
||||||
|
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
|
||||||
|
|
||||||
|
image_hiddens = torch.where(
|
||||||
|
image_keep_mask_hidden,
|
||||||
|
image_hiddens,
|
||||||
|
null_image_hiddens
|
||||||
|
)
|
||||||
|
|
||||||
|
t = t + image_hiddens
|
||||||
|
|
||||||
# mask out image embedding depending on condition dropout
|
# mask out image embedding depending on condition dropout
|
||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
@@ -1585,11 +1589,12 @@ class Unet(nn.Module):
|
|||||||
image_tokens = None
|
image_tokens = None
|
||||||
|
|
||||||
if self.cond_on_image_embeds:
|
if self.cond_on_image_embeds:
|
||||||
|
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
|
||||||
image_tokens = self.image_to_tokens(image_embed)
|
image_tokens = self.image_to_tokens(image_embed)
|
||||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||||
|
|
||||||
image_tokens = torch.where(
|
image_tokens = torch.where(
|
||||||
image_keep_mask,
|
image_keep_mask_embed,
|
||||||
image_tokens,
|
image_tokens,
|
||||||
null_image_embed
|
null_image_embed
|
||||||
)
|
)
|
||||||
@@ -1956,10 +1961,10 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = torch.randn_like(x)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
|
|||||||
@@ -58,8 +58,15 @@ def num_to_groups(num, divisor):
|
|||||||
arr.append(remainder)
|
arr.append(remainder)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
def get_pkg_version():
|
def clamp(value, min_value = None, max_value = None):
|
||||||
return __version__
|
assert exists(min_value) or exists(max_value)
|
||||||
|
if exists(min_value):
|
||||||
|
value = max(value, min_value)
|
||||||
|
|
||||||
|
if exists(max_value):
|
||||||
|
value = min(value, max_value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
# decorators
|
# decorators
|
||||||
|
|
||||||
@@ -227,10 +234,17 @@ class EMA(nn.Module):
|
|||||||
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||||
ma_param.data.copy_(current_param.data)
|
ma_param.data.copy_(current_param.data)
|
||||||
|
|
||||||
|
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
|
||||||
|
ma_buffer.data.copy_(current_buffer.data)
|
||||||
|
|
||||||
def get_current_decay(self):
|
def get_current_decay(self):
|
||||||
epoch = max(0, self.step.item() - self.update_after_step - 1)
|
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
|
||||||
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||||
return 0. if epoch < 0 else min(self.beta, max(self.min_value, value))
|
|
||||||
|
if epoch <= 0:
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
return clamp(value, min_value = self.min_value, max_value = self.beta)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
step = self.step.item()
|
step = self.step.item()
|
||||||
@@ -521,7 +535,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
loaded_obj = torch.load(str(path))
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
if version.parse(__version__) != loaded_obj['version']:
|
if version.parse(__version__) != loaded_obj['version']:
|
||||||
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
|
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||||
|
|
||||||
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.6.13'
|
__version__ = '0.7.0'
|
||||||
|
|||||||
@@ -211,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
|||||||
Loads the model with an appropriate method depending on the tracker
|
Loads the model with an appropriate method depending on the tracker
|
||||||
"""
|
"""
|
||||||
print(print_ribbon(f"Loading model from {recall_source}"))
|
print(print_ribbon(f"Loading model from {recall_source}"))
|
||||||
state_dict = tracker.recall_state_dict(recall_source, **load_config)
|
state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
|
||||||
trainer.load_state_dict(state_dict["trainer"])
|
trainer.load_state_dict(state_dict["trainer"])
|
||||||
print("Model loaded")
|
print("Model loaded")
|
||||||
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
||||||
|
|||||||
Reference in New Issue
Block a user