prepare for cascading diffusion in unet, save the full progressive upsampling architecture to be built next week

This commit is contained in:
Phil Wang
2022-04-15 07:03:28 -07:00
parent bece206699
commit c400d8758c
2 changed files with 35 additions and 14 deletions

View File

@@ -11,7 +11,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import filter2d
from kornia.filters.gaussian import GaussianBlur2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -625,17 +625,6 @@ def Upsample(dim):
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -769,11 +758,25 @@ class Unet(nn.Module):
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1
):
super().__init__()
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
# determine dimensions
self.channels = channels
dims = [channels, *map(lambda m: dim * m, dim_mults)]
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time, image embeddings, and optional text encoding
@@ -856,12 +859,30 @@ class Unet(nn.Module):
time,
*,
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.
):
batch_size, device = x.shape[0], x.device
# add low resolution conditioning, if present
assert not self.lowres_cond and not exists(lowres_cond_img), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
time_tokens = self.time_mlp(time)
# conditional dropout
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.16',
version = '0.0.17',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',