@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):

def module_device(module):
    return next(module.parameters()).device

def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

@contextmanager
def null_context(*args, **kwargs):
    yield
@@ -220,6 +225,7 @@ class XClipAdapter(BaseClipAdapter):
        encoder_output = self.clip.text_transformer(text)
        text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
        text_embed = self.clip.to_text_latent(text_cls)
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)

    @torch.no_grad()
@@ -255,6 +261,7 @@ class CoCaAdapter(BaseClipAdapter):
        text = text[..., :self.max_text_len]
        text_mask = text != 0
        text_embed, text_encodings = self.clip.embed_text(text)
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        return EmbeddedText(text_embed, text_encodings, text_mask)

    @torch.no_grad()
@@ -314,6 +321,7 @@ class OpenAIClipAdapter(BaseClipAdapter):

        text_embed = self.clip.encode_text(text)
        text_encodings = self.text_encodings
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        del self.text_encodings
        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
@@ -505,6 +513,12 @@ class NoiseScheduler(nn.Module):
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def p2_reweigh_loss(self, loss, times):
        if not self.has_p2_loss_reweighting:
            return loss
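# A minimal numerical check of the inversion above (a sketch, not part of the diff):
# q_sample gives x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise,
# so the noise is recovered as (sqrt(1 / alpha_cumprod_t) * x_t - x_0) / sqrt(1 / alpha_cumprod_t - 1).
# The tensor names below are illustrative stand-ins for the scheduler's registered buffers.

import torch

alpha_cumprod_t = torch.tensor(0.7)
x0    = torch.randn(4, 512)
noise = torch.randn(4, 512)

x_t = alpha_cumprod_t.sqrt() * x0 + (1 - alpha_cumprod_t).sqrt() * noise

sqrt_recip   = (1. / alpha_cumprod_t).sqrt()
sqrt_recipm1 = (1. / alpha_cumprod_t - 1.).sqrt()
recovered_noise = (sqrt_recip * x_t - x0) / sqrt_recipm1

assert torch.allclose(recovered_noise, noise, atol = 1e-5)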
@@ -858,7 +872,7 @@ class DiffusionPriorNetwork(nn.Module):
            text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)

        if not exists(mask):
            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
            mask = torch.any(text_encodings != 0., dim = -1)

        # classifier free guidance
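# Sketch of the mask derivation used above (names are illustrative): encodings that were
# masked_fill'ed to 0. across the feature dimension are treated as padding, so a boolean
# mask can be recovered with torch.any over the last dimension.

import torch

text_encodings = torch.randn(2, 4, 8)
text_encodings[:, 2:] = 0.                           # pretend the last two positions are padding

mask = torch.any(text_encodings != 0., dim = -1)     # (2, 4) bool
assert mask[:, :2].all() and not mask[:, 2:].any()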
@@ -911,19 +925,23 @@ class DiffusionPrior(nn.Module):
        image_size = None,
        image_channels = 3,
        timesteps = 1000,
        sample_timesteps = None,
        cond_drop_prob = 0.,
        loss_type = "l2",
        predict_x_start = True,
        beta_schedule = "cosine",
        condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
        sampling_clamp_l2norm = False,
        condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
        sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
        sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
        training_clamp_l2norm = False,
        init_image_embed_l2norm = False,
        image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
        image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
        clip_adapter_overrides = dict()
    ):
        super().__init__()

        self.sample_timesteps = sample_timesteps

        self.noise_scheduler = NoiseScheduler(
            beta_schedule = beta_schedule,
            timesteps = timesteps,
@@ -954,23 +972,32 @@ class DiffusionPrior(nn.Module):
        self.condition_on_text_encodings = condition_on_text_encodings

        # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

        self.predict_x_start = predict_x_start

        # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132

        self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)

        # whether to force an l2norm, similar to clipping denoised, when sampling

        self.sampling_clamp_l2norm = sampling_clamp_l2norm
        self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm

        self.training_clamp_l2norm = training_clamp_l2norm
        self.init_image_embed_l2norm = init_image_embed_l2norm

        # device tracker

        self.register_buffer('_dummy', torch.tensor([True]), persistent = False)

    @property
    def device(self):
        return self._dummy.device

    def l2norm_clamp_embed(self, image_embed):
        return l2norm(image_embed) * self.image_embed_scale

    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
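# Sketch of the scaling trick referenced above (illustrative values, not part of the diff):
# an l2-normalized embedding has unit norm, so each of its D components has variance of
# roughly 1/D, much smaller than the unit-variance noise used by gaussian diffusion.
# Multiplying by sqrt(D) (the default image_embed_scale) restores per-component variance near 1.

import torch
import torch.nn.functional as F

dim = 512
image_embed = F.normalize(torch.randn(1024, dim), dim = -1)   # unit l2 norm per row
print(image_embed.std())                                      # roughly (1 / dim) ** 0.5 ~ 0.044

scaled = image_embed * dim ** 0.5
print(scaled.std())                                           # roughly 1.0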
@@ -978,8 +1005,6 @@ class DiffusionPrior(nn.Module):

        if self.predict_x_start:
            x_recon = pred
            # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
            # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
        else:
            x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1002,21 +1027,81 @@ class DiffusionPrior(nn.Module):
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
        device = self.device

        b = shape[0]
        image_embed = torch.randn(shape, device=device)
    def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
        batch, device = shape[0], self.device
        image_embed = torch.randn(shape, device = device)

        if self.init_image_embed_l2norm:
            image_embed = l2norm(image_embed) * self.image_embed_scale

        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
            times = torch.full((b,), i, device = device, dtype = torch.long)
            times = torch.full((batch,), i, device = device, dtype = torch.long)
            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

        if self.sampling_final_clamp_l2norm and self.predict_x_start:
            image_embed = self.l2norm_clamp_embed(image_embed)

        return image_embed
    @torch.no_grad()
    def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
        batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps

        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]

        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        image_embed = torch.randn(shape, device = device)

        if self.init_image_embed_l2norm:
            image_embed = l2norm(image_embed) * self.image_embed_scale

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            alpha = alphas[time]
            alpha_next = alphas[time_next]

            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)

            if self.predict_x_start:
                x_start = pred
                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
            else:
                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
                pred_noise = pred

            if not self.predict_x_start:
                x_start.clamp_(-1., 1.)

            if self.predict_x_start and self.sampling_clamp_l2norm:
                x_start = self.l2norm_clamp_embed(x_start)

            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
            noise = torch.randn_like(image_embed) if time_next > 0 else 0.

            image_embed = x_start * alpha_next.sqrt() + \
                c1 * noise + \
                c2 * pred_noise

        if self.predict_x_start and self.sampling_final_clamp_l2norm:
            image_embed = self.l2norm_clamp_embed(image_embed)

        return image_embed
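# A self-contained sketch of the DDIM update used in the loop above (eta = 1 keeps the
# stochastic behaviour, eta = 0 is deterministic). The alpha values are made up for
# illustration; in the diff they come from noise_scheduler.alphas_cumprod_prev.

import torch

def ddim_step(x_t, x_start, pred_noise, alpha, alpha_next, eta = 1.):
    # coefficients from Song et al., "Denoising Diffusion Implicit Models"
    c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    c2 = ((1 - alpha_next) - c1 ** 2).sqrt()
    noise = torch.randn_like(x_t)
    return x_start * alpha_next.sqrt() + c1 * noise + c2 * pred_noise

x_t        = torch.randn(4, 512)
x_start    = torch.randn(4, 512)
pred_noise = torch.randn(4, 512)
x_prev = ddim_step(x_t, x_start, pred_noise, alpha = torch.tensor(0.5), alpha_next = torch.tensor(0.7))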
    @torch.no_grad()
    def p_sample_loop(self, *args, timesteps = None, **kwargs):
        timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
        assert timesteps <= self.noise_scheduler.num_timesteps
        is_ddim = timesteps < self.noise_scheduler.num_timesteps

        if not is_ddim:
            return self.p_sample_loop_ddpm(*args, **kwargs)

        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)

    def p_losses(self, image_embed, times, text_cond, noise = None):
        noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -1030,7 +1115,7 @@ class DiffusionPrior(nn.Module):
        )

        if self.predict_x_start and self.training_clamp_l2norm:
            pred = l2norm(pred) * self.image_embed_scale
            pred = self.l2norm_clamp_embed(pred)

        target = noise if not self.predict_x_start else image_embed
@@ -1051,7 +1136,15 @@ class DiffusionPrior(nn.Module):

    @torch.no_grad()
    @eval_decorator
    def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
    def sample(
        self,
        text,
        num_samples_per_batch = 2,
        cond_scale = 1.,
        timesteps = None
    ):
        timesteps = default(timesteps, self.sample_timesteps)

        # in the paper, what they did was
        # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
        text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
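# Sketch of the reranking trick described in the comments above: sample r candidate image
# embeddings per caption, score each against its text embedding, keep the best one.
# Shapes and einops patterns mirror the diff; tensors here are random stand-ins.

import torch
from einops import rearrange, repeat

batch, r, dim = 3, 2, 512
text_embeds  = torch.nn.functional.normalize(torch.randn(batch, dim), dim = -1)
image_embeds = torch.nn.functional.normalize(torch.randn(batch * r, dim), dim = -1)

text_embeds_rep = repeat(text_embeds, 'b d -> (b r) d', r = r)
sims = torch.einsum('n d, n d -> n', text_embeds_rep, image_embeds)
sims = rearrange(sims, '(b r) -> b r', r = r)

top_ind = sims.topk(1, dim = -1).indices                                              # (b, 1)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = r)
best = image_embeds.gather(1, repeat(top_ind, 'b 1 -> b 1 d', d = dim)).squeeze(1)    # (b, d)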
@@ -1066,7 +1159,7 @@ class DiffusionPrior(nn.Module):
        if self.condition_on_text_encodings:
            text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)

        # retrieve original unscaled image embed
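# Hedged usage sketch for the new timesteps argument (assuming a DiffusionPrior trained
# with 1000 timesteps and tokenized text as in the repo's README; values are illustrative):
#
#   image_embed = diffusion_prior.sample(text_tokens, timesteps = 64, cond_scale = 2.)
#
# Passing timesteps below the training count routes sampling through p_sample_loop_ddim,
# while leaving it as None (or equal to the full count) keeps the original DDPM loop.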
@@ -1129,16 +1222,35 @@ class DiffusionPrior(nn.Module):

# decoder

def ConvTransposeUpsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)
    return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)

def NearestUpsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1)
    )

class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)

def Downsample(dim, *, dim_out = None):
    dim_out = default(dim_out, dim)
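# Sketch of what the init above buys (illustrative, outside the diff): by repeating one
# kernel 4x along the output channels, the four pixel-shuffle sub-pixels start out
# identical, so a freshly initialized PixelShuffleUpsample behaves like nearest-neighbor
# upsampling instead of producing a checkerboard pattern.

import torch
import torch.nn as nn
from einops import repeat

conv = nn.Conv2d(8, 8 * 4, 1)
o, i, h, w = conv.weight.shape
weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(weight)
conv.weight.data.copy_(repeat(weight, 'o ... -> (o 4) ...'))
nn.init.zeros_(conv.bias.data)

x = torch.randn(1, 8, 16, 16)
out = nn.PixelShuffle(2)(conv(x))                                      # (1, 8, 32, 32)
assert torch.allclose(out[..., 0, 0], out[..., 0, 1], atol = 1e-6)     # neighboring sub-pixels match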
@@ -1402,7 +1514,7 @@ class Unet(nn.Module):
        cross_embed_downsample_kernel_sizes = (2, 4),
        memory_efficient = False,
        scale_skip_connection = False,
        nearest_upsample = False,
        pixel_shuffle_upsample = True,
        final_conv_kernel_size = 1,
        **kwargs
    ):
@@ -1468,10 +1580,12 @@ class Unet(nn.Module):
        # text encoding conditioning (optional)

        self.text_to_cond = None
        self.text_embed_dim = None

        if cond_on_text_encodings:
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
            self.text_embed_dim = text_embed_dim

        # finer control over whether to condition on image embeddings and text encodings
        # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1514,7 +1628,7 @@ class Unet(nn.Module):

        # upsample klass

        upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
        upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample

        # give memory efficient unet an initial resnet block
@@ -1578,6 +1692,8 @@ class Unet(nn.Module):
        self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
        self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)

        zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it

    # if the current settings for the unet are not correct
    # for cascading DDPM, then reinit the unet with the right settings

    def cast_model_parameters(
@@ -1700,21 +1816,27 @@ class Unet(nn.Module):
        text_tokens = None

        if exists(text_encodings) and self.cond_on_text_encodings:
            assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'

            if not exists(text_mask):
                text_mask = torch.any(text_encodings != 0., dim = -1)

            text_tokens = self.text_to_cond(text_encodings)

            text_tokens = text_tokens[:, :self.max_text_len]
            text_mask = text_mask[:, :self.max_text_len]

            text_tokens_len = text_tokens.shape[1]
            remainder = self.max_text_len - text_tokens_len

            if remainder > 0:
                text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
                text_mask = F.pad(text_mask, (0, remainder), value = False)

            if exists(text_mask):
                if remainder > 0:
                    text_mask = F.pad(text_mask, (0, remainder), value = False)

                text_mask = rearrange(text_mask, 'b n -> b n 1')

            text_mask = rearrange(text_mask, 'b n -> b n 1')
            text_keep_mask = text_mask & text_keep_mask
            assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
            text_keep_mask = text_mask & text_keep_mask

            null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
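# Small sketch of the pad-to-max_text_len step above (illustrative shapes, outside the diff):
# both the projected text tokens and their boolean mask are padded on the sequence dimension
# so every batch conditions on a fixed number of token slots.

import torch
import torch.nn.functional as F

max_text_len = 8
text_tokens = torch.randn(2, 5, 64)            # (batch, seq, cond_dim)
text_mask   = torch.ones(2, 5, dtype = torch.bool)

remainder = max_text_len - text_tokens.shape[1]
if remainder > 0:
    text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))            # pad seq dim with zeros
    text_mask   = F.pad(text_mask, (0, remainder), value = False)     # padded slots masked out

assert text_tokens.shape == (2, 8, 64) and text_mask.shape == (2, 8)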
@@ -1853,6 +1975,7 @@ class Decoder(nn.Module):
        channels = 3,
        vae = tuple(),
        timesteps = 1000,
        sample_timesteps = None,
        image_cond_drop_prob = 0.1,
        text_cond_drop_prob = 0.5,
        loss_type = 'l2',
@@ -1876,7 +1999,8 @@ class Decoder(nn.Module):
        use_dynamic_thres = False, # from the Imagen paper
        dynamic_thres_percentile = 0.9,
        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
        p2_loss_weight_k = 1
        p2_loss_weight_k = 1,
        ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
    ):
        super().__init__()
@@ -1956,6 +2080,11 @@ class Decoder(nn.Module):
            self.unets.append(one_unet)
            self.vaes.append(one_vae.copy_for_eval())

        # sampling timesteps, defaults to non-ddim with full timesteps sampling

        self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
        self.ddim_sampling_eta = ddim_sampling_eta

        # create noise schedulers per unet

        if not exists(beta_schedule):
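# Hedged usage sketch for the new per-unet sampling controls (argument names follow the
# diff, concrete values are illustrative): a cascade of two unets can keep full DDPM
# sampling for the base unet while the upsampler uses a short DDIM schedule.
#
#   decoder = Decoder(
#       unet = (unet1, unet2),
#       image_sizes = (64, 256),
#       timesteps = 1000,
#       sample_timesteps = (None, 30),   # None -> full 1000-step DDPM, 30 -> DDIM
#       ddim_sampling_eta = 0.           # deterministic DDIM updates
#   )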
@@ -1966,7 +2095,9 @@ class Decoder(nn.Module):
        self.noise_schedulers = nn.ModuleList([])

        for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
        for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
            assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'

            noise_scheduler = NoiseScheduler(
                beta_schedule = unet_beta_schedule,
                timesteps = timesteps,
@@ -2067,6 +2198,26 @@ class Decoder(nn.Module):
        for unet, device in zip(self.unets, devices):
            unet.to(device)

    def dynamic_threshold(self, x):
        """ proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """

        # s is the threshold amount
        # static thresholding would just be s = 1
        s = 1.
        if self.use_dynamic_thres:
            s = torch.quantile(
                rearrange(x, 'b ... -> b (...)').abs(),
                self.dynamic_thres_percentile,
                dim = -1
            )

            s.clamp_(min = 1.)
            s = s.view(-1, *((1,) * (x.ndim - 1)))

        # clip by threshold, depending on whether static or dynamic
        x = x.clamp(-s, s) / s
        return x

    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
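# Standalone sketch of dynamic thresholding (Saharia et al., https://arxiv.org/abs/2205.11487),
# mirroring the method above outside the class; the percentile value is illustrative.

import torch
from einops import rearrange

def dynamic_threshold(x, percentile = 0.9):
    # per-sample threshold at the given percentile of absolute values, never below 1
    s = torch.quantile(rearrange(x, 'b ... -> b (...)').abs(), percentile, dim = -1)
    s.clamp_(min = 1.)
    s = s.view(-1, *((1,) * (x.ndim - 1)))
    return x.clamp(-s, s) / s            # clamp, then rescale back into [-1, 1]

x = torch.randn(2, 3, 8, 8) * 3.         # exaggerated values, as can happen with high guidance scales
out = dynamic_threshold(x)
assert out.abs().max() <= 1. + 1e-6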
@@ -2081,21 +2232,7 @@ class Decoder(nn.Module):
            x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

        if clip_denoised:
            # s is the threshold amount
            # static thresholding would just be s = 1
            s = 1.
            if self.use_dynamic_thres:
                s = torch.quantile(
                    rearrange(x_recon, 'b ... -> b (...)').abs(),
                    self.dynamic_thres_percentile,
                    dim = -1
                )

                s.clamp_(min = 1.)
                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))

            # clip by threshold, depending on whether static or dynamic
            x_recon = x_recon.clamp(-s, s) / s
            x_recon = self.dynamic_threshold(x_recon)

        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -2125,7 +2262,7 @@ class Decoder(nn.Module):
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
        device = self.device

        b = shape[0]
@@ -2153,6 +2290,62 @@ class Decoder(nn.Module):
        unnormalize_img = self.unnormalize_img(img)
        return unnormalize_img

    @torch.no_grad()
    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
        batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta

        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]

        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        img = torch.randn(shape, device = device)

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            alpha = alphas[time]
            alpha_next = alphas[time_next]

            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

            if learned_variance:
                pred, _ = pred.chunk(2, dim = 1)

            if predict_x_start:
                x_start = pred
                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
            else:
                x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
                pred_noise = pred

            if clip_denoised:
                x_start = self.dynamic_threshold(x_start)

            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
            noise = torch.randn_like(img) if time_next > 0 else 0.

            img = x_start * alpha_next.sqrt() + \
                c1 * noise + \
                c2 * pred_noise

        img = self.unnormalize_img(img)
        return img
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
|
|
|
|
|
num_timesteps = noise_scheduler.num_timesteps
|
|
|
|
|
|
|
|
|
|
timesteps = default(timesteps, num_timesteps)
|
|
|
|
|
assert timesteps <= num_timesteps
|
|
|
|
|
is_ddim = timesteps < num_timesteps
|
|
|
|
|
|
|
|
|
|
if not is_ddim:
|
|
|
|
|
return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
|
|
|
|
|
|
|
|
|
|
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
|
|
|
|
|
|
|
|
|
|
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
|
|
|
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
|
|
|
|
|
|
|
|
|
@@ -2253,7 +2446,7 @@ class Decoder(nn.Module):
        img = None
        is_cuda = next(self.parameters()).is_cuda

        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):

            context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2282,7 +2475,8 @@ class Decoder(nn.Module):
                    clip_denoised = not is_latent_diffusion,
                    lowres_cond_img = lowres_cond_img,
                    is_latent_diffusion = is_latent_diffusion,
                    noise_scheduler = noise_scheduler
                    noise_scheduler = noise_scheduler,
                    timesteps = sample_timesteps
                )

                img = vae.decode(img)
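# Taken together, each unet in the cascade now samples with its own schedule: the sample
# loop above hands its sample_timesteps (possibly None) to p_sample_loop, which dispatches
# to p_sample_loop_ddpm for a full schedule or p_sample_loop_ddim for a truncated one.
# A minimal sketch of that dispatch logic, with made-up numbers:

def choose_sampler(sample_timesteps, train_timesteps):
    timesteps = sample_timesteps if sample_timesteps is not None else train_timesteps
    assert timesteps <= train_timesteps
    return 'ddim' if timesteps < train_timesteps else 'ddpm'

assert choose_sampler(None, 1000) == 'ddpm'
assert choose_sampler(30, 1000) == 'ddim'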