fix repaint

This commit is contained in:
Phil Wang
2022-07-24 15:29:06 -07:00
parent 417ff808e6
commit 62043acb2f
2 changed files with 94 additions and 54 deletions

View File

@@ -516,6 +516,17 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
) )
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
shape = x_from.shape
noise = default(noise, lambda: torch.randn_like(x_from))
alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
def predict_start_from_noise(self, x_t, t, noise): def predict_start_from_noise(self, x_t, t, noise):
return ( return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
@@ -2432,14 +2443,18 @@ class Decoder(nn.Module):
is_latent_diffusion = False, is_latent_diffusion = False,
lowres_noise_level = None, lowres_noise_level = None,
inpaint_image = None, inpaint_image = None,
inpaint_mask = None inpaint_mask = None,
inpaint_resample_times = 5
): ):
device = self.device device = self.device
b = shape[0] b = shape[0]
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
if exists(inpaint_image): is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
if is_inpaint:
inpaint_image = self.normalize_img(inpaint_image) inpaint_image = self.normalize_img(inpaint_image)
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True) inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float() inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
@@ -2449,10 +2464,15 @@ class Decoder(nn.Module):
if not is_latent_diffusion: if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps): for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long) is_last_timestep = time == 0
if exists(inpaint_image): for r in reversed(range(0, resample_times)):
is_last_resample_step = r == 0
times = torch.full((b,), time, device = device, dtype = torch.long)
if is_inpaint:
# following the repaint paper # following the repaint paper
# https://arxiv.org/abs/2201.09865 # https://arxiv.org/abs/2201.09865
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times) noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
@@ -2473,7 +2493,11 @@ class Decoder(nn.Module):
clip_denoised = clip_denoised clip_denoised = clip_denoised
) )
if exists(inpaint_image): if is_inpaint and not (is_last_timestep or is_last_resample_step):
# in repaint, you renoise and resample up to 10 times every step
img = noise_scheduler.q_sample_from_to(img, times - 1, times)
if is_inpaint:
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask) img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
unnormalize_img = self.unnormalize_img(img) unnormalize_img = self.unnormalize_img(img)
@@ -2497,7 +2521,8 @@ class Decoder(nn.Module):
is_latent_diffusion = False, is_latent_diffusion = False,
lowres_noise_level = None, lowres_noise_level = None,
inpaint_image = None, inpaint_image = None,
inpaint_mask = None inpaint_mask = None,
inpaint_resample_times = 5
): ):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
@@ -2506,7 +2531,10 @@ class Decoder(nn.Module):
times = list(reversed(times.int().tolist())) times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) time_pairs = list(zip(times[:-1], times[1:]))
if exists(inpaint_image): is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
if is_inpaint:
inpaint_image = self.normalize_img(inpaint_image) inpaint_image = self.normalize_img(inpaint_image)
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True) inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float() inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
@@ -2519,12 +2547,17 @@ class Decoder(nn.Module):
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
is_last_timestep = time_next == 0
for r in reversed(range(0, resample_times)):
is_last_resample_step = r == 0
alpha = alphas[time] alpha = alphas[time]
alpha_next = alphas[time_next] alpha_next = alphas[time_next]
time_cond = torch.full((batch,), time, device = device, dtype = torch.long) time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
if exists(inpaint_image): if is_inpaint:
# following the repaint paper # following the repaint paper
# https://arxiv.org/abs/2201.09865 # https://arxiv.org/abs/2201.09865
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond) noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
@@ -2547,12 +2580,17 @@ class Decoder(nn.Module):
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0. noise = torch.randn_like(img) if not is_last_timestep else 0.
img = x_start * alpha_next.sqrt() + \ img = x_start * alpha_next.sqrt() + \
c1 * noise + \ c1 * noise + \
c2 * pred_noise c2 * pred_noise
if is_inpaint and not (is_last_timestep or is_last_resample_step):
# in repaint, you renoise and resample up to 10 times every step
time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)
if exists(inpaint_image): if exists(inpaint_image):
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask) img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
@@ -2658,7 +2696,8 @@ class Decoder(nn.Module):
stop_at_unet_number = None, stop_at_unet_number = None,
distributed = False, distributed = False,
inpaint_image = None, inpaint_image = None,
inpaint_mask = None inpaint_mask = None,
inpaint_resample_times = 5
): ):
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally' assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -2730,7 +2769,8 @@ class Decoder(nn.Module):
noise_scheduler = noise_scheduler, noise_scheduler = noise_scheduler,
timesteps = sample_timesteps, timesteps = sample_timesteps,
inpaint_image = inpaint_image, inpaint_image = inpaint_image,
inpaint_mask = inpaint_mask inpaint_mask = inpaint_mask,
inpaint_resample_times = inpaint_resample_times
) )
img = vae.decode(img) img = vae.decode(img)

View File

@@ -1 +1 @@
__version__ = '1.0.3' __version__ = '1.0.5'