From 2c6c91829d85d19aae511dfa44e0b8b92297ad28 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 22 Apr 2022 11:09:17 -0700
Subject: [PATCH] refactor blurring training augmentation to be taken care of
 by the decoder, with option to downsample to previous resolution before
 upsampling (cascading ddpm). this opens up the possibility of cascading
 latent ddpm

---
 dalle2_pytorch/__init__.py       |  2 +
 dalle2_pytorch/dalle2_pytorch.py | 96 +++++++++++++++++++++++++-------
 setup.py                         |  2 +-
 3 files changed, 78 insertions(+), 22 deletions(-)

diff --git a/dalle2_pytorch/__init__.py b/dalle2_pytorch/__init__.py
index 36b674f..5c3290d 100644
--- a/dalle2_pytorch/__init__.py
+++ b/dalle2_pytorch/__init__.py
@@ -1,2 +1,4 @@
 from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
+
+from dalle2_pytorch.vqgan_vae import VQGanVAE
 from x_clip import CLIP
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index fe948d5..17384d9 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -498,7 +498,7 @@ class DiffusionPrior(nn.Module):
             raise NotImplementedError()
 
         alphas = 1. - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
         alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
 
         timesteps, = betas.shape
@@ -828,9 +828,6 @@ class Unet(nn.Module):
         attn_dim_head = 32,
         attn_heads = 8,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
-        lowres_cond_upsample_mode = 'bilinear',
-        blur_sigma = 0.1,
-        blur_kernel_size = 3,
         sparse_attn = False,
         sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
@@ -847,9 +844,6 @@ class Unet(nn.Module):
         # for eventual cascading diffusion
 
         self.lowres_cond = lowres_cond
-        self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
-        self.lowres_blur_kernel_size = blur_kernel_size
-        self.lowres_blur_sigma = blur_sigma
 
         # determine dimensions
 
@@ -977,13 +971,6 @@ class Unet(nn.Module):
         assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
 
         if exists(lowres_cond_img):
-            if self.training:
-                # when training, blur the low resolution conditional image
-                blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
-                blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
-                lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
-
-            lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
             x = torch.cat((x, lowres_cond_img), dim = 1)
 
         # time conditioning
@@ -1060,6 +1047,44 @@ class Unet(nn.Module):
 
         return self.final_conv(x)
 
+class LowresConditioner(nn.Module):
+    def __init__(
+        self,
+        cond_upsample_mode = 'bilinear',
+        downsample_first = True,
+        blur_sigma = 0.1,
+        blur_kernel_size = 3,
+    ):
+        super().__init__()
+        self.cond_upsample_mode = cond_upsample_mode
+        self.downsample_first = downsample_first
+        self.blur_sigma = blur_sigma
+        self.blur_kernel_size = blur_kernel_size
+
+    def forward(
+        self,
+        cond_fmap,
+        *,
+        target_image_size,
+        downsample_image_size = None,
+        blur_sigma = None,
+        blur_kernel_size = None
+    ):
+        target_image_size = cast_tuple(target_image_size, 2)
+
+        if self.training and self.downsample_first and exists(downsample_image_size):
+            cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
+
+        cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
+
+        if self.training:
+            # when training, blur the low resolution conditional image
+            blur_sigma = default(blur_sigma, self.blur_sigma)
+            blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
+            cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
+
+        return cond_fmap
+
 class Decoder(nn.Module):
     def __init__(
         self,
@@ -1070,7 +1095,11 @@ class Decoder(nn.Module):
         cond_drop_prob = 0.2,
         loss_type = 'l1',
         beta_schedule = 'cosine',
-        image_sizes = None # for cascading ddpm, image size at each stage
+        image_sizes = None, # for cascading ddpm, image size at each stage
+        lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
+        lowres_downsample_first = True, # cascading ddpm - resize to the lower resolution, then to the next conditional resolution, before blurring
+        blur_sigma = 0.1, # cascading ddpm - blur sigma
+        blur_kernel_size = 3, # cascading ddpm - blur kernel size
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -1097,11 +1126,24 @@ class Decoder(nn.Module):
         self.image_sizes = image_sizes
         self.sample_channels = cast_tuple(self.channels, len(image_sizes))
 
+        # cascading ddpm related stuff
+
         lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
         assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
 
+        self.to_lowres_cond = LowresConditioner(
+            cond_upsample_mode = lowres_cond_upsample_mode,
+            downsample_first = lowres_downsample_first,
+            blur_sigma = blur_sigma,
+            blur_kernel_size = blur_kernel_size,
+        )
+
+        # classifier free guidance
+
         self.cond_drop_prob = cond_drop_prob
 
+        # noise schedule
+
         if beta_schedule == "cosine":
             betas = cosine_beta_schedule(timesteps)
         elif beta_schedule == "linear":
@@ -1116,7 +1158,7 @@ class Decoder(nn.Module):
             raise NotImplementedError()
 
         alphas = 1. - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
         alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
 
         timesteps, = betas.shape
@@ -1228,6 +1270,7 @@ class Decoder(nn.Module):
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+
         return img
 
     def q_sample(self, x_start, t, noise=None):
@@ -1267,7 +1310,6 @@ class Decoder(nn.Module):
     @eval_decorator
     def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
-        channels = self.channels
 
         text_encodings = self.get_text_encodings(text) if exists(text) else None
 
@@ -1275,18 +1317,30 @@ class Decoder(nn.Module):
 
         for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
             with self.one_unet_in_gpu(unet = unet):
-                shape = (batch_size, channel, image_size, image_size)
-                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
+                lowres_cond_img = self.to_lowres_cond(
+                    img,
+                    target_image_size = image_size
+                ) if unet.lowres_cond else None
+
+                img = self.p_sample_loop(
+                    unet,
+                    (batch_size, channel, image_size, image_size),
+                    image_embed = image_embed,
+                    text_encodings = text_encodings,
+                    cond_scale = cond_scale,
+                    lowres_cond_img = lowres_cond_img
+                )
 
         return img
 
     def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
+        unet_index = unet_number - 1
 
         unet = self.get_unet(unet_number)
 
-        target_image_size = self.image_sizes[unet_number - 1]
+        target_image_size = self.image_sizes[unet_index]
 
         b, c, h, w, device, = *image.shape, image.device
 
@@ -1300,7 +1354,7 @@ class Decoder(nn.Module):
 
         text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
 
-        lowres_cond_img = image if unet_number > 1 else None
+        lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
 
         ddpm_image = resize_image_to(image, target_image_size)
         return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
diff --git a/setup.py b/setup.py
index 5f41b73..406d421 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.31',
+  version = '0.0.32',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
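
Reviewer note, not part of the patch itself: a minimal sketch of how the relocated blur/downsample options would be driven after this refactor. The CLIP and Unet hyperparameters below are illustrative placeholders in the style of the project README, and `unet = (unet1, unet2)` / `timesteps = 100` are assumed from the existing `Decoder` constructor rather than prescribed by this commit.

    import torch
    from dalle2_pytorch import CLIP, Unet, Decoder

    # a CLIP to decode from (train it first, or load a trained one) - values are placeholders
    clip = CLIP(
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 49408,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        visual_enc_depth = 6,
        visual_image_size = 256,
        visual_patch_size = 32,
        visual_heads = 8
    )

    # base unet - no low resolution conditioning
    unet1 = Unet(
        dim = 32,
        image_embed_dim = 512,
        cond_dim = 128,
        channels = 3,
        dim_mults = (1, 2, 4, 8)
    )

    # super resolution unet - note the blur kwargs are no longer passed here
    unet2 = Unet(
        dim = 32,
        image_embed_dim = 512,
        cond_dim = 128,
        channels = 3,
        dim_mults = (1, 2, 4, 8),
        lowres_cond = True
    )

    decoder = Decoder(
        clip = clip,
        unet = (unet1, unet2),
        image_sizes = (128, 256),       # resolution of each cascade stage
        timesteps = 100,
        blur_sigma = 0.1,               # blur augmentation now lives on the decoder
        blur_kernel_size = 3,
        lowres_downsample_first = True  # 256 -> 128 -> 256 before conditioning unet2
    )

    images = torch.randn(4, 3, 256, 256)

    loss = decoder(images, unet_number = 1)  # base stage, trained at 128 x 128
    loss.backward()

    loss = decoder(images, unet_number = 2)  # conditioned on the blurred 256 -> 128 -> 256 image
    loss.backward()

Because `Decoder.sample` is wrapped in `eval_decorator`, `LowresConditioner` sees `self.training == False` at generation time and skips both the blur and the extra downsample, so the corruption acts purely as a training-time conditioning augmentation, in the spirit of cascaded DDPMs.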