refactor the blurring training augmentation so it is handled by the decoder, with the option to downsample to the previous stage's resolution before upsampling (cascading ddpm). this opens up the possibility of cascading latent ddpm

Phil Wang
2022-04-22 11:09:17 -07:00
parent ad17c69ab6
commit 2c6c91829d
3 changed files with 78 additions and 22 deletions
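
For users, the net effect is that the blur and upsampling settings move off the individual Unet and onto the Decoder, which routes every low resolution conditioning image through a single LowresConditioner module (added below). A minimal usage sketch; the CLIP and Unet hyperparameters, the unet tuple keyword, and the timesteps keyword are illustrative assumptions in the style of the repository README, not taken from this diff:

import torch
from dalle2_pytorch import Unet, Decoder
from x_clip import CLIP

clip = CLIP(
    dim_text = 512, dim_image = 512, dim_latent = 512,
    num_text_tokens = 49408, text_enc_depth = 6, text_seq_len = 256, text_heads = 8,
    visual_enc_depth = 6, visual_image_size = 256, visual_patch_size = 32, visual_heads = 8
)

unet1 = Unet(dim = 32, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4))
unet2 = Unet(dim = 32, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4), lowres_cond = True)

decoder = Decoder(
    clip = clip,
    unet = (unet1, unet2),           # first unet unconditioned, the rest need lowres_cond = True
    image_sizes = (128, 256),        # image size at each stage of the cascade
    timesteps = 100,                 # assumed constructor keyword
    blur_sigma = 0.1,                # blur is now configured once, on the decoder
    blur_kernel_size = 3,
    lowres_downsample_first = True   # downsample to the previous stage's size before upsampling + blur
)

images = torch.randn(2, 3, 256, 256)

loss = decoder(images, unet_number = 1)  # train the base unet
loss.backward()

loss = decoder(images, unet_number = 2)  # train the super-resolving unet
loss.backward()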

dalle2_pytorch/__init__.py

@@ -1,2 +1,4 @@
 from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
+from dalle2_pytorch.vqgan_vae import VQGanVAE
+
 from x_clip import CLIP
dalle2_pytorch/dalle2_pytorch.py

@@ -498,7 +498,7 @@ class DiffusionPrior(nn.Module):
             raise NotImplementedError()
 
         alphas = 1. - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
         alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
 
         timesteps, = betas.shape
@@ -828,9 +828,6 @@ class Unet(nn.Module):
         attn_dim_head = 32,
         attn_heads = 8,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
-        lowres_cond_upsample_mode = 'bilinear',
-        blur_sigma = 0.1,
-        blur_kernel_size = 3,
         sparse_attn = False,
         sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
@@ -847,9 +844,6 @@ class Unet(nn.Module):
         # for eventual cascading diffusion
 
         self.lowres_cond = lowres_cond
-        self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
-        self.lowres_blur_kernel_size = blur_kernel_size
-        self.lowres_blur_sigma = blur_sigma
 
         # determine dimensions
@@ -977,13 +971,6 @@ class Unet(nn.Module):
         assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
 
         if exists(lowres_cond_img):
-            if self.training:
-                # when training, blur the low resolution conditional image
-                blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
-                blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
-                lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
-
-            lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
             x = torch.cat((x, lowres_cond_img), dim = 1)
 
         # time conditioning
@@ -1060,6 +1047,44 @@ class Unet(nn.Module):
         return self.final_conv(x)
 
+class LowresConditioner(nn.Module):
+    def __init__(
+        self,
+        cond_upsample_mode = 'bilinear',
+        downsample_first = True,
+        blur_sigma = 0.1,
+        blur_kernel_size = 3,
+    ):
+        super().__init__()
+        self.cond_upsample_mode = cond_upsample_mode
+        self.downsample_first = downsample_first
+        self.blur_sigma = blur_sigma
+        self.blur_kernel_size = blur_kernel_size
+
+    def forward(
+        self,
+        cond_fmap,
+        *,
+        target_image_size,
+        downsample_image_size = None,
+        blur_sigma = None,
+        blur_kernel_size = None
+    ):
+        target_image_size = cast_tuple(target_image_size, 2)
+
+        if self.training and self.downsample_first and exists(downsample_image_size):
+            # downsample to the previous stage's resolution first, before upsampling back to target
+            cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
+
+        cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
+
+        if self.training:
+            # when training, blur the low resolution conditional image
+            blur_sigma = default(blur_sigma, self.blur_sigma)
+            blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
+            cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
+
+        return cond_fmap
+
 class Decoder(nn.Module):
     def __init__(
         self,
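
The train/eval asymmetry of the new module is the heart of the refactor. A small sketch of its behavior, assuming LowresConditioner is imported from dalle2_pytorch.dalle2_pytorch (this commit does not re-export it from the package root):

import torch
from dalle2_pytorch.dalle2_pytorch import LowresConditioner

cond = LowresConditioner(downsample_first = True, blur_sigma = 0.1, blur_kernel_size = 3)

image = torch.randn(2, 3, 256, 256)  # ground truth images for the current stage

# training (as in Decoder.forward further down): resize down to the previous
# stage's resolution, back up to the target, then gaussian blur - simulating
# the imperfect output of the previous unet in the cascade
cond.train()
lowres_cond_img = cond(image, target_image_size = 256, downsample_image_size = 128)
assert lowres_cond_img.shape == (2, 3, 256, 256)

# eval (as in Decoder.sample further down): a plain bilinear resize, no blur
# and no extra downsampling, since the input really is the previous unet's output
cond.eval()
lowres_cond_img = cond(torch.randn(2, 3, 128, 128), target_image_size = 256)
assert lowres_cond_img.shape == (2, 3, 256, 256)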
@@ -1070,7 +1095,11 @@ class Decoder(nn.Module):
         cond_drop_prob = 0.2,
         loss_type = 'l1',
         beta_schedule = 'cosine',
-        image_sizes = None # for cascading ddpm, image size at each stage
+        image_sizes = None, # for cascading ddpm, image size at each stage
+        lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
+        lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
+        blur_sigma = 0.1, # cascading ddpm - blur sigma
+        blur_kernel_size = 3, # cascading ddpm - blur kernel size
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -1097,11 +1126,24 @@ class Decoder(nn.Module):
         self.image_sizes = image_sizes
         self.sample_channels = cast_tuple(self.channels, len(image_sizes))
 
+        # cascading ddpm related stuff
+
+        lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
+        assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
+
+        self.to_lowres_cond = LowresConditioner(
+            cond_upsample_mode = lowres_cond_upsample_mode,
+            downsample_first = lowres_downsample_first,
+            blur_sigma = blur_sigma,
+            blur_kernel_size = blur_kernel_size,
+        )
+
         # classifier free guidance
 
         self.cond_drop_prob = cond_drop_prob
 
         # noise schedule
 
         if beta_schedule == "cosine":
             betas = cosine_beta_schedule(timesteps)
         elif beta_schedule == "linear":
@@ -1116,7 +1158,7 @@ class Decoder(nn.Module):
             raise NotImplementedError()
 
         alphas = 1. - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        alphas_cumprod = torch.cumprod(alphas, axis = 0)
         alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
 
         timesteps, = betas.shape
@@ -1228,6 +1270,7 @@ class Decoder(nn.Module):
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
 
         return img
 
     def q_sample(self, x_start, t, noise=None):
@@ -1267,7 +1310,6 @@ class Decoder(nn.Module):
     @eval_decorator
     def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
-        channels = self.channels
 
         text_encodings = self.get_text_encodings(text) if exists(text) else None
@@ -1275,18 +1317,30 @@ class Decoder(nn.Module):
         for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
             with self.one_unet_in_gpu(unet = unet):
-                shape = (batch_size, channel, image_size, image_size)
-                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
+                lowres_cond_img = self.to_lowres_cond(
+                    img,
+                    target_image_size = image_size
+                ) if unet.lowres_cond else None
+
+                img = self.p_sample_loop(
+                    unet,
+                    (batch_size, channel, image_size, image_size),
+                    image_embed = image_embed,
+                    text_encodings = text_encodings,
+                    cond_scale = cond_scale,
+                    lowres_cond_img = lowres_cond_img
+                )
 
         return img
 
     def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
+        unet_index = unet_number - 1
 
         unet = self.get_unet(unet_number)
-        target_image_size = self.image_sizes[unet_number - 1]
+        target_image_size = self.image_sizes[unet_index]
 
         b, c, h, w, device, = *image.shape, image.device
@@ -1300,7 +1354,7 @@ class Decoder(nn.Module):
         text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
 
-        lowres_cond_img = image if unet_number > 1 else None
+        lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
 
         ddpm_image = resize_image_to(image, target_image_size)
         return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
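
At sampling time the cascade chains automatically: each stage's output becomes the conditioning image for the next, resized (but never blurred, since eval_decorator puts the decoder in eval mode) by the LowresConditioner. Continuing the sketch above, with a random stand-in for the image embedding that a trained DiffusionPrior would normally supply:

image_embed = torch.randn(2, 512)     # stand-in for the diffusion prior's output
images = decoder.sample(image_embed)  # unet1 samples at 128px, unet2 upscales to 256px
assert images.shape == (2, 3, 256, 256)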

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.31',
+  version = '0.0.32',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',