mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
complete inpainting ability using inpaint_image and inpaint_mask passed into sample function for decoder
This commit is contained in:
@@ -2415,20 +2415,51 @@ class Decoder(nn.Module):
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
|
||||
def p_sample_loop_ddpm(
|
||||
self,
|
||||
unet,
|
||||
shape,
|
||||
image_embed,
|
||||
noise_scheduler,
|
||||
predict_x_start = False,
|
||||
learned_variance = False,
|
||||
clip_denoised = True,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
cond_scale = 1,
|
||||
is_latent_diffusion = False,
|
||||
lowres_noise_level = None,
|
||||
inpaint_image = None,
|
||||
inpaint_mask = None
|
||||
):
|
||||
device = self.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
if exists(inpaint_image):
|
||||
inpaint_image = self.normalize_img(inpaint_image)
|
||||
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
||||
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
||||
inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
|
||||
inpaint_mask = inpaint_mask.bool()
|
||||
|
||||
if not is_latent_diffusion:
|
||||
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||
|
||||
for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
|
||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
||||
|
||||
if exists(inpaint_image):
|
||||
# following the repaint paper
|
||||
# https://arxiv.org/abs/2201.09865
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
torch.full((b,), i, device = device, dtype = torch.long),
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
@@ -2440,11 +2471,32 @@ class Decoder(nn.Module):
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
|
||||
if exists(inpaint_image):
|
||||
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
||||
|
||||
unnormalize_img = self.unnormalize_img(img)
|
||||
return unnormalize_img
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
|
||||
def p_sample_loop_ddim(
|
||||
self,
|
||||
unet,
|
||||
shape,
|
||||
image_embed,
|
||||
noise_scheduler,
|
||||
timesteps,
|
||||
eta = 1.,
|
||||
predict_x_start = False,
|
||||
learned_variance = False,
|
||||
clip_denoised = True,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
cond_scale = 1,
|
||||
is_latent_diffusion = False,
|
||||
lowres_noise_level = None,
|
||||
inpaint_image = None,
|
||||
inpaint_mask = None
|
||||
):
|
||||
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
|
||||
|
||||
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
|
||||
@@ -2452,6 +2504,13 @@ class Decoder(nn.Module):
|
||||
times = list(reversed(times.int().tolist()))
|
||||
time_pairs = list(zip(times[:-1], times[1:]))
|
||||
|
||||
if exists(inpaint_image):
|
||||
inpaint_image = self.normalize_img(inpaint_image)
|
||||
inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
|
||||
inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
|
||||
inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
|
||||
inpaint_mask = inpaint_mask.bool()
|
||||
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
if not is_latent_diffusion:
|
||||
@@ -2463,6 +2522,12 @@ class Decoder(nn.Module):
|
||||
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
|
||||
if exists(inpaint_image):
|
||||
# following the repaint paper
|
||||
# https://arxiv.org/abs/2201.09865
|
||||
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
|
||||
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
|
||||
|
||||
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||
|
||||
if learned_variance:
|
||||
@@ -2486,6 +2551,9 @@ class Decoder(nn.Module):
|
||||
c1 * noise + \
|
||||
c2 * pred_noise
|
||||
|
||||
if exists(inpaint_image):
|
||||
img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
|
||||
|
||||
img = self.unnormalize_img(img)
|
||||
return img
|
||||
|
||||
@@ -2585,6 +2653,8 @@ class Decoder(nn.Module):
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None,
|
||||
distributed = False,
|
||||
inpaint_image = None,
|
||||
inpaint_mask = None
|
||||
):
|
||||
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
||||
|
||||
@@ -2598,6 +2668,8 @@ class Decoder(nn.Module):
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
|
||||
assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
|
||||
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
@@ -2609,6 +2681,8 @@ class Decoder(nn.Module):
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
|
||||
|
||||
with context:
|
||||
# prepare low resolution conditioning for upsamplers
|
||||
|
||||
lowres_cond_img = lowres_noise_level = None
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
|
||||
@@ -2619,12 +2693,16 @@ class Decoder(nn.Module):
|
||||
lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
|
||||
lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
|
||||
|
||||
# latent diffusion
|
||||
|
||||
is_latent_diffusion = isinstance(vae, VQGanVAE)
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||
|
||||
# denoising loop for image
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
shape,
|
||||
@@ -2638,7 +2716,9 @@ class Decoder(nn.Module):
|
||||
lowres_noise_level = lowres_noise_level,
|
||||
is_latent_diffusion = is_latent_diffusion,
|
||||
noise_scheduler = noise_scheduler,
|
||||
timesteps = sample_timesteps
|
||||
timesteps = sample_timesteps,
|
||||
inpaint_image = inpaint_image,
|
||||
inpaint_mask = inpaint_mask
|
||||
)
|
||||
|
||||
img = vae.decode(img)
|
||||
|
||||
Reference in New Issue
Block a user