mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
refactor blurring training augmentation to be taken care of by the decoder, with option to downsample to previous resolution before upsampling (cascading ddpm). this opens up the possibility of cascading latent ddpm
This commit is contained in:
@@ -1,2 +1,4 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from x_clip import CLIP
|
||||
|
||||
@@ -498,7 +498,7 @@ class DiffusionPrior(nn.Module):
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
@@ -828,9 +828,6 @@ class Unet(nn.Module):
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 8,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_cond_upsample_mode = 'bilinear',
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
sparse_attn = False,
|
||||
sparse_attn_window = 8, # window size for sparse attention
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
@@ -847,9 +844,6 @@ class Unet(nn.Module):
|
||||
# for eventual cascading diffusion
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
||||
self.lowres_blur_kernel_size = blur_kernel_size
|
||||
self.lowres_blur_sigma = blur_sigma
|
||||
|
||||
# determine dimensions
|
||||
|
||||
@@ -977,13 +971,6 @@ class Unet(nn.Module):
|
||||
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
|
||||
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
# time conditioning
|
||||
@@ -1060,6 +1047,44 @@ class Unet(nn.Module):
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cond_upsample_mode = 'bilinear',
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_upsample_mode = cond_upsample_mode
|
||||
self.downsample_first = downsample_first
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_kernel_size = blur_kernel_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
cond_fmap,
|
||||
*,
|
||||
target_image_size,
|
||||
downsample_image_size = None,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
target_image_size = cast_tuple(target_image_size, 2)
|
||||
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
return cond_fmap
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1070,7 +1095,11 @@ class Decoder(nn.Module):
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
beta_schedule = 'cosine',
|
||||
image_sizes = None # for cascading ddpm, image size at each stage
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -1097,11 +1126,24 @@ class Decoder(nn.Module):
|
||||
self.image_sizes = image_sizes
|
||||
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||
|
||||
# cascading ddpm related stuff
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
|
||||
self.to_lowres_cond = LowresConditioner(
|
||||
cond_upsample_mode = lowres_cond_upsample_mode,
|
||||
downsample_first = lowres_downsample_first,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
# noise schedule
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -1116,7 +1158,7 @@ class Decoder(nn.Module):
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
@@ -1228,6 +1270,7 @@ class Decoder(nn.Module):
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
@@ -1267,7 +1310,6 @@ class Decoder(nn.Module):
|
||||
@eval_decorator
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
channels = self.channels
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
@@ -1275,18 +1317,30 @@ class Decoder(nn.Module):
|
||||
|
||||
for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
|
||||
with self.one_unet_in_gpu(unet = unet):
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
||||
lowres_cond_img = self.to_lowres_cond(
|
||||
img,
|
||||
target_image_size = image_size
|
||||
) if unet.lowres_cond else None
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
(batch_size, channel, image_size, image_size),
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
unet_index = unet_number - 1
|
||||
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
target_image_size = self.image_sizes[unet_number - 1]
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
@@ -1300,7 +1354,7 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
lowres_cond_img = image if unet_number > 1 else None
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
ddpm_image = resize_image_to(image, target_image_size)
|
||||
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user