use null container pattern to cleanup some conditionals, save more cleanup for next week

Phil Wang
2022-04-22 15:23:11 -07:00
parent a8b5d5d753
commit 05b74be69a
3 changed files with 44 additions and 18 deletions
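For context: the change applies the null object pattern (the "null container" of the commit title). Instead of storing None for unets that operate directly in pixel space and guarding every use with exists(vae), the Decoder now stores a NullVQGanVAE whose encode/decode are identity functions. A minimal standalone sketch of the idea, with hypothetical names (NullVAE, decode_before/decode_after) that are not from this repo:

import torch
from torch import nn

class NullVAE(nn.Module):
    # do-nothing stand-in: encoding and decoding leave the tensor untouched
    def encode(self, x):
        return x

    def decode(self, x):
        return x

# before: every call site must branch on whether a VAE exists
def decode_before(img, vae = None):
    if vae is not None:
        img = vae.decode(img)
    return img

# after: the null object makes the call unconditional
def decode_after(img, vae = NullVAE()):
    return vae.decode(img)

img = torch.randn(1, 3, 8, 8)
assert torch.equal(decode_before(img), decode_after(img))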

dalle2_pytorch/dalle2_pytorch.py

@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
 from kornia.filters import gaussian_blur2d
 
 from dalle2_pytorch.tokenizer import tokenizer
+from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
 
 # use x-clip
@@ -48,11 +49,11 @@ def is_list_str(x):
         return False
     return all([type(el) == str for el in x])
 
-def pad_tuple_to_length(t, length):
+def pad_tuple_to_length(t, length, fillvalue = None):
     remain_length = length - len(t)
     if remain_length <= 0:
         return t
-    return (*t, *((None,) * remain_length))
+    return (*t, *((fillvalue,) * remain_length))
 
 # for controlling freezing of CLIP
@@ -1135,12 +1136,15 @@ class Decoder(nn.Module):
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
 
         unets = cast_tuple(unet)
-        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
+        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
 
         self.unets = nn.ModuleList([])
         self.vaes = nn.ModuleList([])
 
         for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
+            assert isinstance(one_unet, Unet)
+            assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
+
             is_first = ind == 0
 
             latent_dim = one_vae.encoded_dim if exists(one_vae) else None
@@ -1152,7 +1156,7 @@ class Decoder(nn.Module):
             )
 
             self.unets.append(one_unet)
-            self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
+            self.vaes.append(one_vae.copy_for_eval())
 
         # unet image sizes
@@ -1362,8 +1366,7 @@ class Decoder(nn.Module):
             if unet.lowres_cond:
                 lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
 
-            if exists(vae):
-                image_size //= (2 ** vae.layers)
+            image_size = vae.get_encoded_fmap_size(image_size)
             shape = (batch_size, vae.encoded_dim, image_size, image_size)
 
             if exists(lowres_cond_img):
@@ -1378,7 +1381,6 @@ class Decoder(nn.Module):
                 lowres_cond_img = lowres_cond_img
             )
 
-            if exists(vae):
-                img = vae.decode(img)
+            img = vae.decode(img)
 
         return img
@@ -1415,7 +1417,6 @@ class Decoder(nn.Module):
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
 
         image = resize_image_to(image, target_image_size)
-        if exists(vae):
-            vae.eval()
-            with torch.no_grad():
-                image = vae.encode(image)
+        vae.eval()
+        with torch.no_grad():
+            image = vae.encode(image)
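The new fillvalue argument is what wires the pattern into the Decoder: unet slots without an explicit VAE are padded with a NullVQGanVAE instead of None, so the constructor loop can assert types and call copy_for_eval() unconditionally. A quick sketch of the helper's behaviour, reusing the two definitions from the diffs on this page (the 'real-vae' placeholder is illustrative only):

from torch import nn

def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

class NullVQGanVAE(nn.Module):
    def __init__(self, *, channels):
        super().__init__()
        self.encoded_dim = channels
        self.layers = 0

    def copy_for_eval(self):
        return self

    def encode(self, x):
        return x

    def decode(self, x):
        return x

# one explicit VAE slot for the first unet, null stand-ins for the rest
vaes = pad_tuple_to_length(('real-vae',), 3, fillvalue = NullVQGanVAE(channels = 3))
assert len(vaes) == 3
assert vaes[1] is vaes[2]  # padded slots share the single fillvalue instance

Since fillvalue is evaluated once at the call site, all padded slots share one instance; that is harmless here because the null VAE holds no state.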

dalle2_pytorch/vqgan_vae.py

@@ -287,6 +287,28 @@ class VQGanAttention(nn.Module):
         return out + residual
 
+class NullVQGanVAE(nn.Module):
+    def __init__(
+        self,
+        *,
+        channels
+    ):
+        super().__init__()
+        self.encoded_dim = channels
+        self.layers = 0
+
+    def get_encoded_fmap_size(self, size):
+        return size
+
+    def copy_for_eval(self):
+        return self
+
+    def encode(self, x):
+        return x
+
+    def decode(self, x):
+        return x
+
 class VQGanVAE(nn.Module):
     def __init__(
         self,
@@ -407,6 +429,9 @@ class VQGanVAE(nn.Module):
         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
 
+    def get_encoded_fmap_size(self, image_size):
+        return image_size // (2 ** self.layers)
+
     def copy_for_eval(self):
         device = next(self.parameters()).device
         vae_copy = copy.deepcopy(self.cpu())
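get_encoded_fmap_size centralizes the arithmetic that was previously inlined in the Decoder (image_size //= (2 ** vae.layers)): each VQGAN down-sampling layer halves the spatial resolution, and NullVQGanVAE's identity version agrees with the same formula at layers = 0. A small worked check (free-function form for illustration only):

# each down-sampling layer halves the spatial size of the feature map
def get_encoded_fmap_size(image_size, layers):
    return image_size // (2 ** layers)

assert get_encoded_fmap_size(256, 3) == 32    # 256 -> 128 -> 64 -> 32
assert get_encoded_fmap_size(256, 0) == 256   # NullVQGanVAE: size unchanged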

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.37',
+  version = '0.0.38',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',