From 05b74be69a2ff309f68e94cb370a717141aa9436 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 22 Apr 2022 15:23:11 -0700
Subject: [PATCH] use null container pattern to cleanup some conditionals,
 save more cleanup for next week

---
 dalle2_pytorch/dalle2_pytorch.py | 35 ++++++++++++++++----------------
 dalle2_pytorch/vqgan_vae.py      | 25 +++++++++++++++++++++++
 setup.py                         |  2 +-
 3 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 2b27df3..cf0870f 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
 from kornia.filters import gaussian_blur2d
 
 from dalle2_pytorch.tokenizer import tokenizer
+from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
 
 # use x-clip
 
@@ -48,11 +49,11 @@ def is_list_str(x):
         return False
     return all([type(el) == str for el in x])
 
-def pad_tuple_to_length(t, length):
+def pad_tuple_to_length(t, length, fillvalue = None):
     remain_length = length - len(t)
     if remain_length <= 0:
         return t
-    return (*t, *((None,) * remain_length))
+    return (*t, *((fillvalue,) * remain_length))
 
 # for controlling freezing of CLIP
 
@@ -1135,12 +1136,15 @@ class Decoder(nn.Module):
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
 
         unets = cast_tuple(unet)
-        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
+        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
 
         self.unets = nn.ModuleList([])
         self.vaes = nn.ModuleList([])
 
         for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
+            assert isinstance(one_unet, Unet)
+            assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
+
             is_first = ind == 0
 
             latent_dim = one_vae.encoded_dim if exists(one_vae) else None
@@ -1152,7 +1156,7 @@ class Decoder(nn.Module):
             )
 
             self.unets.append(one_unet)
-            self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
+            self.vaes.append(one_vae.copy_for_eval())
 
         # unet image sizes
 
@@ -1362,12 +1366,11 @@ class Decoder(nn.Module):
             if unet.lowres_cond:
                 lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
 
-            if exists(vae):
-                image_size //= (2 ** vae.layers)
-                shape = (batch_size, vae.encoded_dim, image_size, image_size)
+            image_size = vae.get_encoded_fmap_size(image_size)
+            shape = (batch_size, vae.encoded_dim, image_size, image_size)
 
-                if exists(lowres_cond_img):
-                    lowres_cond_img = vae.encode(lowres_cond_img)
+            if exists(lowres_cond_img):
+                lowres_cond_img = vae.encode(lowres_cond_img)
 
             img = self.p_sample_loop(
                 unet,
@@ -1378,8 +1381,7 @@ class Decoder(nn.Module):
                 lowres_cond_img = lowres_cond_img
             )
 
-            if exists(vae):
-                img = vae.decode(img)
+            img = vae.decode(img)
 
         return img
 
@@ -1415,13 +1417,12 @@ class Decoder(nn.Module):
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
         image = resize_image_to(image, target_image_size)
 
-        if exists(vae):
-            vae.eval()
-            with torch.no_grad():
-                image = vae.encode(image)
+        vae.eval()
+        with torch.no_grad():
+            image = vae.encode(image)
 
-                if exists(lowres_cond_img):
-                    lowres_cond_img = vae.encode(lowres_cond_img)
+            if exists(lowres_cond_img):
+                lowres_cond_img = vae.encode(lowres_cond_img)
 
         return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index 953cba0..380cd42 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -287,6 +287,28 @@ class VQGanAttention(nn.Module):
 
         return out + residual
 
+class NullVQGanVAE(nn.Module):
+    def __init__(
+        self,
+        *,
+        channels
+    ):
+        super().__init__()
+        self.encoded_dim = channels
+        self.layers = 0
+
+    def get_encoded_fmap_size(self, size):
+        return size
+
+    def copy_for_eval(self):
+        return self
+
+    def encode(self, x):
+        return x
+
+    def decode(self, x):
+        return x
+
 class VQGanVAE(nn.Module):
     def __init__(
         self,
@@ -407,6 +429,9 @@ class VQGanVAE(nn.Module):
         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
 
+    def get_encoded_fmap_size(self, image_size):
+        return image_size // (2 ** self.layers)
+
     def copy_for_eval(self):
         device = next(self.parameters()).device
         vae_copy = copy.deepcopy(self.cpu())
diff --git a/setup.py b/setup.py
index c9360b6..9c3781b 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.37',
+  version = '0.0.38',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
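
Editor's note: the fillvalue generalization of pad_tuple_to_length is small enough to check in isolation. Below is a standalone copy of the patched helper with a quick sanity check of the new argument; the asserts are illustrative and are not part of the repository.

def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

assert pad_tuple_to_length((1,), 3) == (1, None, None)        # old behavior, still the default
assert pad_tuple_to_length((1,), 3, fillvalue = 0) == (1, 0, 0)  # new: pad with an arbitrary value
assert pad_tuple_to_length((1, 2, 3), 2) == (1, 2, 3)         # already long enough: returned unchanged

This is what lets Decoder pad the per-unet vae tuple with NullVQGanVAE instances instead of None.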
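Editor's note: the "null container pattern" named in the subject line is the classic null-object pattern. NullVQGanVAE answers the same interface as VQGanVAE (get_encoded_fmap_size, copy_for_eval, encode, decode) with identity or no-op behavior, so call sites such as Decoder.sample and Decoder.forward can drop their `if exists(vae)` branches. A minimal self-contained sketch of the idea follows; the Codec/NullCodec names are hypothetical, chosen for illustration, and are not part of dalle2_pytorch.

class Codec:
    # toy "real" codec: encoding shrinks a feature-map size, like VQGanVAE
    layers = 1

    def get_encoded_fmap_size(self, size):
        return size // (2 ** self.layers)

    def encode(self, x):
        return x  # a real implementation would compress here

    def decode(self, x):
        return x  # ...and reconstruct here

class NullCodec:
    # null object: same interface, every method an identity or no-op
    layers = 0

    def get_encoded_fmap_size(self, size):
        return size

    def encode(self, x):
        return x

    def decode(self, x):
        return x

# before: the list held None entries and every use was guarded by a check
# after: missing slots are padded with NullCodec(), so calls are unconditional
codecs = [Codec(), NullCodec()]
for codec in codecs:
    size = codec.get_encoded_fmap_size(256)          # 128 for Codec, 256 for NullCodec
    out = codec.decode(codec.encode([0.0] * size))   # no existence check needed

The payoff visible in the diff above: five `if exists(...)` guards disappear from Decoder while behavior for unets with no VAE stays exactly the same.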