mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 01:34:19 +01:00
fix repaint
This commit is contained in:
@@ -516,6 +516,17 @@ class NoiseScheduler(nn.Module):
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
|
||||
shape = x_from.shape
|
||||
noise = default(noise, lambda: torch.randn_like(x_from))
|
||||
|
||||
alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
|
||||
sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
|
||||
alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
|
||||
sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
|
||||
|
||||
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
||||
@@ -2432,14 +2443,18 @@ class Decoder(nn.Module):
|
||||
is_latent_diffusion = False,
|
||||
lowres_noise_level = None,
|
||||
inpaint_image = None,
|
||||
inpaint_mask = None
|
||||
inpaint_mask = None,
|
||||
inpaint_resample_times = 5
|
||||
):
|
||||
device = self.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
if exists(inpaint_image):
|
||||
is_inpaint = exists(inpaint_image)
|
||||
resample_times = inpaint_resample_times if is_inpaint else 1
|
||||
|
||||
if is_inpaint:
|
||||
inpaint_image = self.normalize_img(inpaint_image)
|
||||
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
||||
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
||||
@@ -2449,31 +2464,40 @@ class Decoder(nn.Module):
|
||||
if not is_latent_diffusion:
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
|
||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
||||
for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
|
||||
is_last_timestep = time == 0
|
||||
|
||||
if exists(inpaint_image):
|
||||
# following the repaint paper
|
||||
# https://arxiv.org/abs/2201.09865
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
for r in reversed(range(0, resample_times)):
|
||||
is_last_resample_step = r == 0
|
||||
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
predict_x_start = predict_x_start,
|
||||
noise_scheduler = noise_scheduler,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
times = torch.full((b,), time, device = device, dtype = torch.long)
|
||||
|
||||
if exists(inpaint_image):
|
||||
if is_inpaint:
|
||||
# following the repaint paper
|
||||
# https://arxiv.org/abs/2201.09865
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
predict_x_start = predict_x_start,
|
||||
noise_scheduler = noise_scheduler,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
|
||||
if is_inpaint and not (is_last_timestep or is_last_resample_step):
|
||||
# in repaint, you renoise and resample up to 10 times every step
|
||||
img = noise_scheduler.q_sample_from_to(img, times - 1, times)
|
||||
|
||||
if is_inpaint:
|
||||
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
||||
|
||||
unnormalize_img = self.unnormalize_img(img)
|
||||
@@ -2497,7 +2521,8 @@ class Decoder(nn.Module):
|
||||
is_latent_diffusion = False,
|
||||
lowres_noise_level = None,
|
||||
inpaint_image = None,
|
||||
inpaint_mask = None
|
||||
inpaint_mask = None,
|
||||
inpaint_resample_times = 5
|
||||
):
|
||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
||||
|
||||
@@ -2506,7 +2531,10 @@ class Decoder(nn.Module):
|
||||
times = list(reversed(times.int().tolist()))
|
||||
time_pairs = list(zip(times[:-1], times[1:]))
|
||||
|
||||
if exists(inpaint_image):
|
||||
is_inpaint = exists(inpaint_image)
|
||||
resample_times = inpaint_resample_times if is_inpaint else 1
|
||||
|
||||
if is_inpaint:
|
||||
inpaint_image = self.normalize_img(inpaint_image)
|
||||
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
||||
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
||||
@@ -2519,39 +2547,49 @@ class Decoder(nn.Module):
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
||||
alpha = alphas[time]
|
||||
alpha_next = alphas[time_next]
|
||||
is_last_timestep = time_next == 0
|
||||
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
for r in reversed(range(0, resample_times)):
|
||||
is_last_resample_step = r == 0
|
||||
|
||||
if exists(inpaint_image):
|
||||
# following the repaint paper
|
||||
# https://arxiv.org/abs/2201.09865
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
alpha = alphas[time]
|
||||
alpha_next = alphas[time_next]
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = pred.chunk(2, dim = 1)
|
||||
if is_inpaint:
|
||||
# following the repaint paper
|
||||
# https://arxiv.org/abs/2201.09865
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
|
||||
if predict_x_start:
|
||||
x_start = pred
|
||||
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
|
||||
else:
|
||||
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
|
||||
pred_noise = pred
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
|
||||
if clip_denoised:
|
||||
x_start = self.dynamic_threshold(x_start)
|
||||
if learned_variance:
|
||||
pred, _ = pred.chunk(2, dim = 1)
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
noise = torch.randn_like(img) if time_next > 0 else 0.
|
||||
if predict_x_start:
|
||||
x_start = pred
|
||||
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
|
||||
else:
|
||||
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
|
||||
pred_noise = pred
|
||||
|
||||
img = x_start * alpha_next.sqrt() + \
|
||||
c1 * noise + \
|
||||
c2 * pred_noise
|
||||
if clip_denoised:
|
||||
x_start = self.dynamic_threshold(x_start)
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
noise = torch.randn_like(img) if not is_last_timestep else 0.
|
||||
|
||||
img = x_start * alpha_next.sqrt() + \
|
||||
c1 * noise + \
|
||||
c2 * pred_noise
|
||||
|
||||
if is_inpaint and not (is_last_timestep or is_last_resample_step):
|
||||
# in repaint, you renoise and resample up to 10 times every step
|
||||
time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
|
||||
img = noise_scheduler.q_sample_from_to(img, time_cond, time_next_cond)
|
||||
|
||||
if exists(inpaint_image):
|
||||
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
||||
@@ -2658,7 +2696,8 @@ class Decoder(nn.Module):
|
||||
stop_at_unet_number = None,
|
||||
distributed = False,
|
||||
inpaint_image = None,
|
||||
inpaint_mask = None
|
||||
inpaint_mask = None,
|
||||
inpaint_resample_times = 5
|
||||
):
|
||||
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
||||
|
||||
@@ -2730,7 +2769,8 @@ class Decoder(nn.Module):
|
||||
noise_scheduler = noise_scheduler,
|
||||
timesteps = sample_timesteps,
|
||||
inpaint_image = inpaint_image,
|
||||
inpaint_mask = inpaint_mask
|
||||
inpaint_mask = inpaint_mask,
|
||||
inpaint_resample_times = inpaint_resample_times
|
||||
)
|
||||
|
||||
img = vae.decode(img)
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.0.3'
|
||||
__version__ = '1.0.4'
|
||||
|
||||
Reference in New Issue
Block a user