Compare commits

...

7 Commits

4 changed files with 20 additions and 17 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
github: [lucidrains]
github: [nousr, Veldrovive, lucidrains]

View File

@@ -146,7 +146,7 @@ def resize_image_to(
scale_factors = target_image_size / orig_image_size
out = resize(image, scale_factors = scale_factors, **kwargs)
else:
out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)
out = F.interpolate(image, target_image_size, mode = 'nearest')
if exists(clamp_range):
out = out.clamp(*clamp_range)
@@ -1550,6 +1550,7 @@ class Unet(nn.Module):
init_conv_kernel_size = 7,
resnet_groups = 8,
num_resnet_blocks = 2,
init_cross_embed = True,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1578,7 +1579,7 @@ class Unet(nn.Module):
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
@@ -1731,7 +1732,10 @@ class Unet(nn.Module):
]))
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
out_dim_in = dim + (channels if lowres_cond else 0)
self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
@@ -1923,7 +1927,7 @@ class Unet(nn.Module):
hiddens.append(x)
x = attn(x)
hiddens.append(x)
hiddens.append(x.contiguous())
if exists(post_downsample):
x = post_downsample(x)
@@ -1951,13 +1955,16 @@ class Unet(nn.Module):
x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t)
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
return self.to_out(x)
class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
downsample_mode_nearest = False,
blur_prob = 0.5,
blur_sigma = 0.6,
blur_kernel_size = 3,
@@ -1965,8 +1972,6 @@ class LowresConditioner(nn.Module):
):
super().__init__()
self.downsample_first = downsample_first
self.downsample_mode_nearest = downsample_mode_nearest
self.input_image_range = input_image_range
self.blur_prob = blur_prob
@@ -1983,7 +1988,7 @@ class LowresConditioner(nn.Module):
blur_kernel_size = None
):
if self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True)
# blur is only applied 50% of the time
# section 3.1 in https://arxiv.org/abs/2106.15282
@@ -2010,7 +2015,7 @@ class LowresConditioner(nn.Module):
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)
return cond_fmap
class Decoder(nn.Module):
@@ -2033,7 +2038,6 @@ class Decoder(nn.Module):
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
blur_sigma = 0.6, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
@@ -2169,6 +2173,7 @@ class Decoder(nn.Module):
# random crop sizes (for super-resoluting unets at the end of cascade?)
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'
# predict x0 config
@@ -2183,11 +2188,8 @@ class Decoder(nn.Module):
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.lowres_downsample_mode_nearest = lowres_downsample_mode_nearest
self.to_lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
downsample_mode_nearest = lowres_downsample_mode_nearest,
blur_prob = blur_prob,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
@@ -2510,7 +2512,7 @@ class Decoder(nn.Module):
shape = (batch_size, channel, image_size, image_size)
if unet.lowres_cond:
lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = self.lowres_downsample_mode_nearest)
lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)
is_latent_diffusion = isinstance(vae, VQGanVAE)
image_size = vae.get_encoded_fmap_size(image_size)
@@ -2580,7 +2582,7 @@ class Decoder(nn.Module):
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
image = resize_image_to(image, target_image_size, nearest = True)
if exists(random_crop_size):
aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)

View File

@@ -225,6 +225,7 @@ class UnetConfig(BaseModel):
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16
init_cross_embed: bool = True
class Config:
extra = "allow"

View File

@@ -1 +1 @@
__version__ = '0.23.9'
__version__ = '0.24.3'