diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index b1f202f..cf4c197 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -516,6 +516,17 @@ class NoiseScheduler(nn.Module): extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) + def q_sample_from_to(self, x_from, from_t, to_t, noise = None): + shape = x_from.shape + noise = default(noise, lambda: torch.randn_like(x_from)) + + alpha = extract(self.sqrt_alphas_cumprod, from_t, shape) + sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape) + alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape) + sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape) + + return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha + def predict_start_from_noise(self, x_t, t, noise): return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - @@ -2432,14 +2443,18 @@ class Decoder(nn.Module): is_latent_diffusion = False, lowres_noise_level = None, inpaint_image = None, - inpaint_mask = None + inpaint_mask = None, + inpaint_resample_times = 5 ): device = self.device b = shape[0] img = torch.randn(shape, device = device) - if exists(inpaint_image): + is_inpaint = exists(inpaint_image) + resample_times = inpaint_resample_times if is_inpaint else 1 + + if is_inpaint: inpaint_image = self.normalize_img(inpaint_image) inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True) inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float() @@ -2449,31 +2464,40 @@ class Decoder(nn.Module): if not is_latent_diffusion: lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) - for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps): - times = torch.full((b,), i, device = device, dtype = torch.long) + for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps): + is_last_timestep = time == 0 - if exists(inpaint_image): - # following the repaint paper - # https://arxiv.org/abs/2201.09865 - noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times) - img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) + for r in reversed(range(0, resample_times)): + is_last_resample_step = r == 0 - img = self.p_sample( - unet, - img, - times, - image_embed = image_embed, - text_encodings = text_encodings, - cond_scale = cond_scale, - lowres_cond_img = lowres_cond_img, - lowres_noise_level = lowres_noise_level, - predict_x_start = predict_x_start, - noise_scheduler = noise_scheduler, - learned_variance = learned_variance, - clip_denoised = clip_denoised - ) + times = torch.full((b,), time, device = device, dtype = torch.long) - if exists(inpaint_image): + if is_inpaint: + # following the repaint paper + # https://arxiv.org/abs/2201.09865 + noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times) + img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) + + img = self.p_sample( + unet, + img, + times, + image_embed = image_embed, + text_encodings = text_encodings, + cond_scale = cond_scale, + lowres_cond_img = lowres_cond_img, + lowres_noise_level = lowres_noise_level, + predict_x_start = predict_x_start, + noise_scheduler = noise_scheduler, + learned_variance = learned_variance, + clip_denoised = clip_denoised + ) + + if is_inpaint and not (is_last_timestep or is_last_resample_step): + # in repaint, you renoise and resample up to 10 times every step + img = noise_scheduler.q_sample_from_to(img, times - 1, times) + + if is_inpaint: img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask) unnormalize_img = self.unnormalize_img(img) @@ -2497,7 +2521,8 @@ class Decoder(nn.Module): is_latent_diffusion = False, lowres_noise_level = None, inpaint_image = None, - inpaint_mask = None + inpaint_mask = None, + inpaint_resample_times = 5 ): batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta @@ -2506,7 +2531,10 @@ class Decoder(nn.Module): times = list(reversed(times.int().tolist())) time_pairs = list(zip(times[:-1], times[1:])) - if exists(inpaint_image): + is_inpaint = exists(inpaint_image) + resample_times = inpaint_resample_times if is_inpaint else 1 + + if is_inpaint: inpaint_image = self.normalize_img(inpaint_image) inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True) inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float() @@ -2519,39 +2547,49 @@ class Decoder(nn.Module): lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): - alpha = alphas[time] - alpha_next = alphas[time_next] + is_last_timestep = time_next == 0 - time_cond = torch.full((batch,), time, device = device, dtype = torch.long) + for r in reversed(range(0, resample_times)): + is_last_resample_step = r == 0 - if exists(inpaint_image): - # following the repaint paper - # https://arxiv.org/abs/2201.09865 - noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond) - img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) + alpha = alphas[time] + alpha_next = alphas[time_next] - pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level) + time_cond = torch.full((batch,), time, device = device, dtype = torch.long) - if learned_variance: - pred, _ = pred.chunk(2, dim = 1) + if is_inpaint: + # following the repaint paper + # https://arxiv.org/abs/2201.09865 + noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond) + img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) - if predict_x_start: - x_start = pred - pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred) - else: - x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred) - pred_noise = pred + pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level) - if clip_denoised: - x_start = self.dynamic_threshold(x_start) + if learned_variance: + pred, _ = pred.chunk(2, dim = 1) - c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() - c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() - noise = torch.randn_like(img) if time_next > 0 else 0. + if predict_x_start: + x_start = pred + pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred) + else: + x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred) + pred_noise = pred - img = x_start * alpha_next.sqrt() + \ - c1 * noise + \ - c2 * pred_noise + if clip_denoised: + x_start = self.dynamic_threshold(x_start) + + c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() + c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() + noise = torch.randn_like(img) if not is_last_timestep else 0. + + img = x_start * alpha_next.sqrt() + \ + c1 * noise + \ + c2 * pred_noise + + if is_inpaint and not (is_last_timestep or is_last_resample_step): + # in repaint, you renoise and resample up to 10 times every step + time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long) + img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond) if exists(inpaint_image): img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask) @@ -2658,7 +2696,8 @@ class Decoder(nn.Module): stop_at_unet_number = None, distributed = False, inpaint_image = None, - inpaint_mask = None + inpaint_mask = None, + inpaint_resample_times = 5 ): assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally' @@ -2730,7 +2769,8 @@ class Decoder(nn.Module): noise_scheduler = noise_scheduler, timesteps = sample_timesteps, inpaint_image = inpaint_image, - inpaint_mask = inpaint_mask + inpaint_mask = inpaint_mask, + inpaint_resample_times = inpaint_resample_times ) img = vae.decode(img) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 3f6fab6..858de17 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.0.3' +__version__ = '1.0.5'