complete inpainting ability using inpaint_image and inpaint_mask passed into sample function for decoder

Phil Wang
2022-07-19 09:26:55 -07:00
parent d88c7ba56c
commit 723bf0abba
3 changed files with 87 additions and 7 deletions
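The change threads two new keyword arguments, `inpaint_image` and `inpaint_mask`, from `Decoder.sample` down through both the DDPM and DDIM sampling loops. A minimal usage sketch, assuming an already-trained `decoder` and a precomputed `image_embed`; the setup names below are illustrative, not part of this commit:

```python
import torch

# hypothetical setup: a trained Decoder and an image embedding from CLIP / the prior
# decoder = Decoder(...)
# image_embed = clip.embed_image(images).image_embed

inpaint_image = torch.randn(4, 3, 256, 256)    # images whose known regions should be kept (random placeholders here)
inpaint_mask = torch.ones(4, 256, 256).bool()  # True = keep pixel from inpaint_image
inpaint_mask[:, 64:192, 64:192] = False        # False = region the decoder generates

images = decoder.sample(
    image_embed = image_embed,
    inpaint_image = inpaint_image,
    inpaint_mask = inpaint_mask
)
```

Per the new assert below, the mask is boolean with shape `(batch, height, width)`: `True` marks pixels kept from `inpaint_image`, `False` marks the region to be generated.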

README.md

@@ -1049,8 +1049,8 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
 - [x] allow for unet to be able to condition non-cross attention style as well
 - [x] speed up inference, read up on papers (ddim)
-- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
-- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
+- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
+- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2

 ## Citations

dalle2_pytorch/dalle2_pytorch.py

@@ -2415,20 +2415,51 @@ class Decoder(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
+    def p_sample_loop_ddpm(
+        self,
+        unet,
+        shape,
+        image_embed,
+        noise_scheduler,
+        predict_x_start = False,
+        learned_variance = False,
+        clip_denoised = True,
+        lowres_cond_img = None,
+        text_encodings = None,
+        cond_scale = 1,
+        is_latent_diffusion = False,
+        lowres_noise_level = None,
+        inpaint_image = None,
+        inpaint_mask = None
+    ):
         device = self.device
         b = shape[0]
         img = torch.randn(shape, device = device)

+        if exists(inpaint_image):
+            inpaint_image = self.normalize_img(inpaint_image)
+            inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
+            inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
+            inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
+            inpaint_mask = inpaint_mask.bool()
+
         if not is_latent_diffusion:
             lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

         for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
+            times = torch.full((b,), i, device = device, dtype = torch.long)
+
+            if exists(inpaint_image):
+                # following the repaint paper
+                # https://arxiv.org/abs/2201.09865
+                noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
+                img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
+
             img = self.p_sample(
                 unet,
                 img,
-                torch.full((b,), i, device = device, dtype = torch.long),
+                times,
                 image_embed = image_embed,
                 text_encodings = text_encodings,
                 cond_scale = cond_scale,
@@ -2440,11 +2471,32 @@ class Decoder(nn.Module):
                 clip_denoised = clip_denoised
             )

+        if exists(inpaint_image):
+            img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
+
         unnormalize_img = self.unnormalize_img(img)
         return unnormalize_img

     @torch.no_grad()
-    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None):
+    def p_sample_loop_ddim(
+        self,
+        unet,
+        shape,
+        image_embed,
+        noise_scheduler,
+        timesteps,
+        eta = 1.,
+        predict_x_start = False,
+        learned_variance = False,
+        clip_denoised = True,
+        lowres_cond_img = None,
+        text_encodings = None,
+        cond_scale = 1,
+        is_latent_diffusion = False,
+        lowres_noise_level = None,
+        inpaint_image = None,
+        inpaint_mask = None
+    ):
         batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta

         times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
@@ -2452,6 +2504,13 @@ class Decoder(nn.Module):
         times = list(reversed(times.int().tolist()))
         time_pairs = list(zip(times[:-1], times[1:]))

+        if exists(inpaint_image):
+            inpaint_image = self.normalize_img(inpaint_image)
+            inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
+            inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
+            inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
+            inpaint_mask = inpaint_mask.bool()
+
         img = torch.randn(shape, device = device)

         if not is_latent_diffusion:
@@ -2463,6 +2522,12 @@ class Decoder(nn.Module):
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

+            if exists(inpaint_image):
+                # following the repaint paper
+                # https://arxiv.org/abs/2201.09865
+                noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
+                img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
+
             pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)

             if learned_variance:
@@ -2486,6 +2551,9 @@ class Decoder(nn.Module):
                   c1 * noise + \
                   c2 * pred_noise

+        if exists(inpaint_image):
+            img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
+
         img = self.unnormalize_img(img)
         return img
@@ -2585,6 +2653,8 @@ class Decoder(nn.Module):
         cond_scale = 1.,
         stop_at_unet_number = None,
         distributed = False,
+        inpaint_image = None,
+        inpaint_mask = None
     ):
         assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -2598,6 +2668,8 @@ class Decoder(nn.Module):
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

+        assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
+
         img = None
         is_cuda = next(self.parameters()).is_cuda
@@ -2609,6 +2681,8 @@ class Decoder(nn.Module):
             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()

             with context:
+                # prepare low resolution conditioning for upsamplers
+
                 lowres_cond_img = lowres_noise_level = None
                 shape = (batch_size, channel, image_size, image_size)
@@ -2619,12 +2693,16 @@ class Decoder(nn.Module):
                     lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
                     lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)

+                # latent diffusion
+
                 is_latent_diffusion = isinstance(vae, VQGanVAE)
                 image_size = vae.get_encoded_fmap_size(image_size)
                 shape = (batch_size, vae.encoded_dim, image_size, image_size)

                 lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

+                # denoising loop for image
+
                 img = self.p_sample_loop(
                     unet,
                     shape,
@@ -2638,7 +2716,9 @@ class Decoder(nn.Module):
                     lowres_noise_level = lowres_noise_level,
                     is_latent_diffusion = is_latent_diffusion,
                     noise_scheduler = noise_scheduler,
-                    timesteps = sample_timesteps
+                    timesteps = sample_timesteps,
+                    inpaint_image = inpaint_image,
+                    inpaint_mask = inpaint_mask
                 )

                 img = vae.decode(img)
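Taken together, both loops implement the simple per-step replacement from the RePaint paper (https://arxiv.org/abs/2201.09865): at every timestep the known region is overwritten with a forward-noised copy of the original image before the denoising step, and the clean image is composited back in once more after the loop, so the kept region of the output is exactly the input rather than its denoised reconstruction. A distilled sketch of one step, paraphrasing the diff above; `denoise_fn` stands in for the `p_sample` / DDIM update and is not a function in the repository:

```python
def repaint_step(img, t, inpaint_image, inpaint_mask, noise_scheduler, denoise_fn):
    # noise the ground-truth image forward to the current timestep t
    noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = t)

    # composite: keep the (noised) ground truth where the mask is True,
    # keep the current sample where it is False
    img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)

    # one reverse diffusion step on the composited image
    return denoise_fn(img, t)
```

Note that while the todo item cites RePaint's resampler, the paper's additional resampling (re-noising and re-denoising to harmonize the masked boundary) does not appear in these hunks; only the masked replacement is added.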

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '0.25.2'
+__version__ = '0.26.0'