|
|
|
|
@@ -522,7 +522,7 @@ class NoiseScheduler(nn.Module):
|
|
|
|
|
|
|
|
|
|
def predict_noise_from_start(self, x_t, t, x0):
|
|
|
|
|
return (
|
|
|
|
|
(x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \
|
|
|
|
|
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
|
|
|
|
|
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -1542,10 +1542,10 @@ class Unet(nn.Module):
|
|
|
|
|
self_attn = False,
|
|
|
|
|
attn_dim_head = 32,
|
|
|
|
|
attn_heads = 16,
|
|
|
|
|
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
|
|
|
|
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
|
|
|
|
|
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
|
|
|
|
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
|
|
|
|
|
sparse_attn = False,
|
|
|
|
|
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
|
|
|
|
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
|
|
|
|
cond_on_text_encodings = False,
|
|
|
|
|
max_text_len = 256,
|
|
|
|
|
cond_on_image_embeds = False,
|
|
|
|
|
@@ -2100,7 +2100,7 @@ class Decoder(nn.Module):
|
|
|
|
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
|
|
|
|
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
|
|
|
|
use_noise_for_lowres_cond = False, # whether to use Imagen-like noising for low resolution conditioning
|
|
|
|
|
use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2
|
|
|
|
|
use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2
|
|
|
|
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
|
|
|
|
blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
|
|
|
|
|
blur_sigma = 0.6, # cascading ddpm - blur sigma
|
|
|
|
|
@@ -2371,10 +2371,10 @@ class Decoder(nn.Module):
|
|
|
|
|
x = x.clamp(-s, s) / s
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
|
|
|
|
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
|
|
|
|
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
|
|
|
|
|
|
|
|
|
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
|
|
|
|
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
|
|
|
|
|
|
|
|
|
|
if learned_variance:
|
|
|
|
|
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
|
|
|
|
@@ -2406,29 +2406,60 @@ class Decoder(nn.Module):
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
|
|
|
|
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
|
|
|
|
|
b, *_, device = *x.shape, x.device
|
|
|
|
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
|
|
|
|
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
noise = torch.randn_like(x)
|
|
|
|
|
# no noise when t == 0
|
|
|
|
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
|
|
|
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
|
|
|
|
|
def p_sample_loop_ddpm(
|
|
|
|
|
self,
|
|
|
|
|
unet,
|
|
|
|
|
shape,
|
|
|
|
|
image_embed,
|
|
|
|
|
noise_scheduler,
|
|
|
|
|
predict_x_start = False,
|
|
|
|
|
learned_variance = False,
|
|
|
|
|
clip_denoised = True,
|
|
|
|
|
lowres_cond_img = None,
|
|
|
|
|
text_encodings = None,
|
|
|
|
|
cond_scale = 1,
|
|
|
|
|
is_latent_diffusion = False,
|
|
|
|
|
lowres_noise_level = None,
|
|
|
|
|
inpaint_image = None,
|
|
|
|
|
inpaint_mask = None
|
|
|
|
|
):
|
|
|
|
|
device = self.device
|
|
|
|
|
|
|
|
|
|
b = shape[0]
|
|
|
|
|
img = torch.randn(shape, device = device)
|
|
|
|
|
|
|
|
|
|
if exists(inpaint_image):
|
|
|
|
|
inpaint_image = self.normalize_img(inpaint_image)
|
|
|
|
|
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
|
|
|
|
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
|
|
|
|
inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
|
|
|
|
|
inpaint_mask = inpaint_mask.bool()
|
|
|
|
|
|
|
|
|
|
if not is_latent_diffusion:
|
|
|
|
|
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
|
|
|
|
|
|
|
|
|
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
|
|
|
|
|
times = torch.full((b,), i, device = device, dtype = torch.long)
|
|
|
|
|
|
|
|
|
|
if exists(inpaint_image):
|
|
|
|
|
# following the repaint paper
|
|
|
|
|
# https://arxiv.org/abs/2201.09865
|
|
|
|
|
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
|
|
|
|
|
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
|
|
|
|
|
|
|
|
|
img = self.p_sample(
|
|
|
|
|
unet,
|
|
|
|
|
img,
|
|
|
|
|
torch.full((b,), i, device = device, dtype = torch.long),
|
|
|
|
|
times,
|
|
|
|
|
image_embed = image_embed,
|
|
|
|
|
text_encodings = text_encodings,
|
|
|
|
|
cond_scale = cond_scale,
|
|
|
|
|
@@ -2440,11 +2471,32 @@ class Decoder(nn.Module):
|
|
|
|
|
clip_denoised = clip_denoised
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if exists(inpaint_image):
|
|
|
|
|
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
|
|
|
|
|
|
|
|
|
unnormalize_img = self.unnormalize_img(img)
|
|
|
|
|
return unnormalize_img
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
|
|
|
|
|
def p_sample_loop_ddim(
|
|
|
|
|
self,
|
|
|
|
|
unet,
|
|
|
|
|
shape,
|
|
|
|
|
image_embed,
|
|
|
|
|
noise_scheduler,
|
|
|
|
|
timesteps,
|
|
|
|
|
eta = 1.,
|
|
|
|
|
predict_x_start = False,
|
|
|
|
|
learned_variance = False,
|
|
|
|
|
clip_denoised = True,
|
|
|
|
|
lowres_cond_img = None,
|
|
|
|
|
text_encodings = None,
|
|
|
|
|
cond_scale = 1,
|
|
|
|
|
is_latent_diffusion = False,
|
|
|
|
|
lowres_noise_level = None,
|
|
|
|
|
inpaint_image = None,
|
|
|
|
|
inpaint_mask = None
|
|
|
|
|
):
|
|
|
|
|
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
|
|
|
|
|
|
|
|
|
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
|
|
|
|
@@ -2452,6 +2504,13 @@ class Decoder(nn.Module):
|
|
|
|
|
times = list(reversed(times.int().tolist()))
|
|
|
|
|
time_pairs = list(zip(times[:-1], times[1:]))
|
|
|
|
|
|
|
|
|
|
if exists(inpaint_image):
|
|
|
|
|
inpaint_image = self.normalize_img(inpaint_image)
|
|
|
|
|
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
|
|
|
|
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
|
|
|
|
inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
|
|
|
|
|
inpaint_mask = inpaint_mask.bool()
|
|
|
|
|
|
|
|
|
|
img = torch.randn(shape, device = device)
|
|
|
|
|
|
|
|
|
|
if not is_latent_diffusion:
|
|
|
|
|
@@ -2463,6 +2522,12 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
|
|
|
|
|
|
|
|
|
if exists(inpaint_image):
|
|
|
|
|
# following the repaint paper
|
|
|
|
|
# https://arxiv.org/abs/2201.09865
|
|
|
|
|
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
|
|
|
|
|
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
|
|
|
|
|
|
|
|
|
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
|
|
|
|
|
if learned_variance:
|
|
|
|
|
@@ -2486,6 +2551,9 @@ class Decoder(nn.Module):
|
|
|
|
|
c1 * noise + \
|
|
|
|
|
c2 * pred_noise
|
|
|
|
|
|
|
|
|
|
if exists(inpaint_image):
|
|
|
|
|
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
|
|
|
|
|
|
|
|
|
img = self.unnormalize_img(img)
|
|
|
|
|
return img
|
|
|
|
|
|
|
|
|
|
@@ -2585,6 +2653,8 @@ class Decoder(nn.Module):
|
|
|
|
|
cond_scale = 1.,
|
|
|
|
|
stop_at_unet_number = None,
|
|
|
|
|
distributed = False,
|
|
|
|
|
inpaint_image = None,
|
|
|
|
|
inpaint_mask = None
|
|
|
|
|
):
|
|
|
|
|
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
|
|
|
|
|
|
|
|
|
@@ -2598,6 +2668,8 @@ class Decoder(nn.Module):
|
|
|
|
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
|
|
|
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
|
|
|
|
|
|
|
|
|
assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
|
|
|
|
|
|
|
|
|
|
img = None
|
|
|
|
|
is_cuda = next(self.parameters()).is_cuda
|
|
|
|
|
|
|
|
|
|
@@ -2609,6 +2681,8 @@ class Decoder(nn.Module):
|
|
|
|
|
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
|
|
|
|
|
|
|
|
|
with context:
|
|
|
|
|
# prepare low resolution conditioning for upsamplers
|
|
|
|
|
|
|
|
|
|
lowres_cond_img = lowres_noise_level = None
|
|
|
|
|
shape = (batch_size, channel, image_size, image_size)
|
|
|
|
|
|
|
|
|
|
@@ -2619,12 +2693,16 @@ class Decoder(nn.Module):
|
|
|
|
|
lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
|
|
|
|
|
lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
|
|
|
|
|
|
|
|
|
|
# latent diffusion
|
|
|
|
|
|
|
|
|
|
is_latent_diffusion = isinstance(vae, VQGanVAE)
|
|
|
|
|
image_size = vae.get_encoded_fmap_size(image_size)
|
|
|
|
|
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
|
|
|
|
|
|
|
|
|
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
|
|
|
|
|
|
|
|
|
# denoising loop for image
|
|
|
|
|
|
|
|
|
|
img = self.p_sample_loop(
|
|
|
|
|
unet,
|
|
|
|
|
shape,
|
|
|
|
|
@@ -2638,7 +2716,9 @@ class Decoder(nn.Module):
|
|
|
|
|
lowres_noise_level = lowres_noise_level,
|
|
|
|
|
is_latent_diffusion = is_latent_diffusion,
|
|
|
|
|
noise_scheduler = noise_scheduler,
|
|
|
|
|
timesteps = sample_timesteps
|
|
|
|
|
timesteps = sample_timesteps,
|
|
|
|
|
inpaint_image = inpaint_image,
|
|
|
|
|
inpaint_mask = inpaint_mask
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
img = vae.decode(img)
|
|
|
|
|
|