mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
use null container pattern to cleanup some conditionals, save more cleanup for next week
This commit is contained in:
@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
|
|||||||
from kornia.filters import gaussian_blur2d
|
from kornia.filters import gaussian_blur2d
|
||||||
|
|
||||||
from dalle2_pytorch.tokenizer import tokenizer
|
from dalle2_pytorch.tokenizer import tokenizer
|
||||||
|
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||||
|
|
||||||
# use x-clip
|
# use x-clip
|
||||||
|
|
||||||
@@ -48,11 +49,11 @@ def is_list_str(x):
|
|||||||
return False
|
return False
|
||||||
return all([type(el) == str for el in x])
|
return all([type(el) == str for el in x])
|
||||||
|
|
||||||
def pad_tuple_to_length(t, length):
|
def pad_tuple_to_length(t, length, fillvalue = None):
|
||||||
remain_length = length - len(t)
|
remain_length = length - len(t)
|
||||||
if remain_length <= 0:
|
if remain_length <= 0:
|
||||||
return t
|
return t
|
||||||
return (*t, *((None,) * remain_length))
|
return (*t, *((fillvalue,) * remain_length))
|
||||||
|
|
||||||
# for controlling freezing of CLIP
|
# for controlling freezing of CLIP
|
||||||
|
|
||||||
@@ -1135,12 +1136,15 @@ class Decoder(nn.Module):
|
|||||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||||
|
|
||||||
unets = cast_tuple(unet)
|
unets = cast_tuple(unet)
|
||||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
|
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||||
|
|
||||||
self.unets = nn.ModuleList([])
|
self.unets = nn.ModuleList([])
|
||||||
self.vaes = nn.ModuleList([])
|
self.vaes = nn.ModuleList([])
|
||||||
|
|
||||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||||
|
assert isinstance(one_unet, Unet)
|
||||||
|
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||||
|
|
||||||
is_first = ind == 0
|
is_first = ind == 0
|
||||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||||
|
|
||||||
@@ -1152,7 +1156,7 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.unets.append(one_unet)
|
self.unets.append(one_unet)
|
||||||
self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
|
self.vaes.append(one_vae.copy_for_eval())
|
||||||
|
|
||||||
# unet image sizes
|
# unet image sizes
|
||||||
|
|
||||||
@@ -1362,12 +1366,11 @@ class Decoder(nn.Module):
|
|||||||
if unet.lowres_cond:
|
if unet.lowres_cond:
|
||||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||||
|
|
||||||
if exists(vae):
|
image_size = vae.get_encoded_fmap_size(image_size)
|
||||||
image_size //= (2 ** vae.layers)
|
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
|
||||||
|
|
||||||
if exists(lowres_cond_img):
|
if exists(lowres_cond_img):
|
||||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||||
|
|
||||||
img = self.p_sample_loop(
|
img = self.p_sample_loop(
|
||||||
unet,
|
unet,
|
||||||
@@ -1378,8 +1381,7 @@ class Decoder(nn.Module):
|
|||||||
lowres_cond_img = lowres_cond_img
|
lowres_cond_img = lowres_cond_img
|
||||||
)
|
)
|
||||||
|
|
||||||
if exists(vae):
|
img = vae.decode(img)
|
||||||
img = vae.decode(img)
|
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@@ -1415,13 +1417,12 @@ class Decoder(nn.Module):
|
|||||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||||
image = resize_image_to(image, target_image_size)
|
image = resize_image_to(image, target_image_size)
|
||||||
|
|
||||||
if exists(vae):
|
vae.eval()
|
||||||
vae.eval()
|
with torch.no_grad():
|
||||||
with torch.no_grad():
|
image = vae.encode(image)
|
||||||
image = vae.encode(image)
|
|
||||||
|
|
||||||
if exists(lowres_cond_img):
|
if exists(lowres_cond_img):
|
||||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||||
|
|
||||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||||
|
|
||||||
|
|||||||
@@ -287,6 +287,28 @@ class VQGanAttention(nn.Module):
|
|||||||
|
|
||||||
return out + residual
|
return out + residual
|
||||||
|
|
||||||
|
class NullVQGanVAE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channels
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.encoded_dim = channels
|
||||||
|
self.layers = 0
|
||||||
|
|
||||||
|
def get_encoded_fmap_size(self, size):
|
||||||
|
return size
|
||||||
|
|
||||||
|
def copy_for_eval(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
class VQGanVAE(nn.Module):
|
class VQGanVAE(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -407,6 +429,9 @@ class VQGanVAE(nn.Module):
|
|||||||
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||||
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||||
|
|
||||||
|
def get_encoded_fmap_size(self, image_size):
|
||||||
|
return image_size // (2 ** self.layers)
|
||||||
|
|
||||||
def copy_for_eval(self):
|
def copy_for_eval(self):
|
||||||
device = next(self.parameters()).device
|
device = next(self.parameters()).device
|
||||||
vae_copy = copy.deepcopy(self.cpu())
|
vae_copy = copy.deepcopy(self.cpu())
|
||||||
|
|||||||
Reference in New Issue
Block a user