|
|
|
|
@@ -77,6 +77,11 @@ def cast_tuple(val, length = None):
|
|
|
|
|
def module_device(module):
    """
    return the device a module lives on

    the device of the first parameter is used; a module without parameters falls back
    to its first buffer, and a module with neither is reported as living on the cpu
    (calling `next` on an empty parameter iterator would otherwise raise StopIteration)
    """
    first_param = next(module.parameters(), None)

    if first_param is not None:
        return first_param.device

    first_buffer = next(module.buffers(), None)

    if first_buffer is not None:
        return first_buffer.device

    return torch.device('cpu')
|
|
|
|
|
|
|
|
|
|
def zero_init_(m):
    """
    zero out, in place, the weight of module `m` and, when one is present, its bias
    """
    nn.init.zeros_(m.weight)

    # a module may carry `bias = None` or no bias attribute at all - nothing more to reset then
    bias = getattr(m, 'bias', None)

    if exists(bias):
        nn.init.zeros_(bias)
|
|
|
|
|
|
|
|
|
|
@contextmanager
def null_context(*args, **kwargs):
    """
    context manager that does nothing on enter or exit

    any positional and keyword arguments are accepted and ignored, so it can be
    called with the same arguments as another context manager it stands in for
    the value bound by `with ... as` is None
    """
    yield
|
|
|
|
|
@@ -867,7 +872,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
|
|
|
|
|
|
|
|
|
if not exists(mask):
|
|
|
|
|
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
|
|
|
|
|
mask = torch.any(text_encodings != 0., dim = -1)
|
|
|
|
|
|
|
|
|
|
# classifier free guidance
|
|
|
|
|
|
|
|
|
|
@@ -1200,7 +1205,6 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
|
|
|
|
|
if self.condition_on_text_encodings:
|
|
|
|
|
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
|
|
|
|
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
|
|
|
|
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
|
|
|
|
|
|
|
|
|
# timestep conditioning from ddpm
|
|
|
|
|
@@ -1218,16 +1222,35 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
|
|
|
|
|
# decoder
|
|
|
|
|
|
|
|
|
|
def ConvTransposeUpsample(dim, dim_out = None):
    """
    build a transposed convolution going from `dim` to `dim_out` channels

    `dim_out` is resolved through `default(dim_out, dim)`
    kernel size 4, stride 2 and padding 1 give an output with twice the input height and width
    """
    dim_out = default(dim_out, dim)
    return nn.ConvTranspose2d(dim, dim_out, kernel_size = 4, stride = 2, padding = 1)
|
|
|
|
|
class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf

    a 1x1 convolution expands `dim` channels to `dim_out * 4`, followed by SiLU and
    PixelShuffle(2), which trades the factor of 4 in channels for a 2x larger height and width
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        """
        re-initialize `conv` in place: weight and bias are both overwritten

        only a quarter of the output kernels are drawn (kaiming uniform), each one is then
        repeated 4 times in a row along the output channel axis, so every group of 4 channels
        that PixelShuffle(2) folds into one 2x2 patch starts from the same kernel
        the bias is set to zero
        """
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        """
        run `x` through conv -> SiLU -> PixelShuffle(2)
        """
        return self.net(x)

def NearestUpsample(dim, dim_out = None):
    """
    build a nearest neighbor 2x upsample followed by a 3x3 convolution (padding 1)
    going from `dim` to `dim_out` channels, `dim_out` is resolved through `default(dim_out, dim)`
    """
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1)
    )
|
|
|
|
|
|
|
|
|
|
def Downsample(dim, *, dim_out = None):
|
|
|
|
|
dim_out = default(dim_out, dim)
|
|
|
|
|
@@ -1491,7 +1514,7 @@ class Unet(nn.Module):
|
|
|
|
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
|
|
|
|
memory_efficient = False,
|
|
|
|
|
scale_skip_connection = False,
|
|
|
|
|
nearest_upsample = False,
|
|
|
|
|
pixel_shuffle_upsample = True,
|
|
|
|
|
final_conv_kernel_size = 1,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
@@ -1605,7 +1628,7 @@ class Unet(nn.Module):
|
|
|
|
|
|
|
|
|
|
# upsample klass
|
|
|
|
|
|
|
|
|
|
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
|
|
|
|
|
upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
|
|
|
|
|
|
|
|
|
|
# give memory efficient unet an initial resnet block
|
|
|
|
|
|
|
|
|
|
@@ -1669,6 +1692,8 @@ class Unet(nn.Module):
|
|
|
|
|
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
|
|
|
|
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
|
|
|
|
|
|
|
|
|
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
|
|
|
|
|
|
|
|
|
|
# if the current settings for the unet are not correct
|
|
|
|
|
# for cascading DDPM, then reinit the unet with the right settings
|
|
|
|
|
def cast_model_parameters(
|
|
|
|
|
@@ -1793,21 +1818,25 @@ class Unet(nn.Module):
|
|
|
|
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
|
|
|
|
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
|
|
|
|
|
|
|
|
|
if not exists(text_mask):
|
|
|
|
|
text_mask = torch.any(text_encodings != 0., dim = -1)
|
|
|
|
|
|
|
|
|
|
text_tokens = self.text_to_cond(text_encodings)
|
|
|
|
|
|
|
|
|
|
text_tokens = text_tokens[:, :self.max_text_len]
|
|
|
|
|
text_mask = text_mask[:, :self.max_text_len]
|
|
|
|
|
|
|
|
|
|
text_tokens_len = text_tokens.shape[1]
|
|
|
|
|
remainder = self.max_text_len - text_tokens_len
|
|
|
|
|
|
|
|
|
|
if remainder > 0:
|
|
|
|
|
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
|
|
|
|
|
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
|
|
|
|
|
|
|
|
|
if exists(text_mask):
|
|
|
|
|
if remainder > 0:
|
|
|
|
|
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
|
|
|
|
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
|
|
|
|
|
|
|
|
|
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
|
|
|
|
text_keep_mask = text_mask & text_keep_mask
|
|
|
|
|
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
|
|
|
|
|
text_keep_mask = text_mask & text_keep_mask
|
|
|
|
|
|
|
|
|
|
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
|
|
|
|
|
|
|
|
|
|
@@ -2414,9 +2443,6 @@ class Decoder(nn.Module):
|
|
|
|
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
|
|
|
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
|
|
|
|
|
|
|
|
|
if self.condition_on_text_encodings:
|
|
|
|
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
|
|
|
|
|
|
|
|
|
img = None
|
|
|
|
|
is_cuda = next(self.parameters()).is_cuda
|
|
|
|
|
|
|
|
|
|
@@ -2500,9 +2526,6 @@ class Decoder(nn.Module):
|
|
|
|
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
|
|
|
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
|
|
|
|
|
|
|
|
|
if self.condition_on_text_encodings:
|
|
|
|
|
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
|
|
|
|
|
|
|
|
|
|
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
|
|
|
|
image = resize_image_to(image, target_image_size)
|
|
|
|
|
|
|
|
|
|
|