diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index f85a61f..4b2a164 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -11,7 +11,7 @@ from einops.layers.torch import Rearrange from einops_exts import rearrange_many, repeat_many, check_shape from einops_exts.torch import EinopsToAndFrom -from kornia.filters import filter2d +from kornia.filters.gaussian import GaussianBlur2d from dalle2_pytorch.tokenizer import tokenizer @@ -625,17 +625,6 @@ def Upsample(dim): def Downsample(dim): return nn.Conv2d(dim, dim, 4, 2, 1) -class Blur(nn.Module): - def __init__(self): - super().__init__() - filt = torch.Tensor([1, 2, 1]) - self.register_buffer('filt', filt) - - def forward(self, x): - filt = self.filt - filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1') - return filter2d(x, filt, normalized = True) - class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() @@ -769,11 +758,25 @@ class Unet(nn.Module): out_dim = None, dim_mults=(1, 2, 4, 8), channels = 3, + lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ + lowres_cond_upsample_mode = 'bilinear', + blur_sigma = 0.1 ): super().__init__() + + # for eventual cascading diffusion + + self.lowres_cond = lowres_cond + self.lowres_cond_upsample_mode = lowres_cond_upsample_mode + self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma)) + + # determine dimensions + self.channels = channels - dims = [channels, *map(lambda m: dim * m, dim_mults)] + init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis + + dims = [init_channels, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # time, image embeddings, and optional text encoding @@ -856,12 +859,30 @@ class Unet(nn.Module): time, *, image_embed, + 
lowres_cond_img = None, text_encodings = None, cond_drop_prob = 0. ): batch_size, device = x.shape[0], x.device + + # add low resolution conditioning, if present + + assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' + + if exists(lowres_cond_img): + if self.training: + # when training, blur the low resolution conditional image + lowres_cond_img = self.lowres_cond_blur(lowres_cond_img) + + lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode) + x = torch.cat((x, lowres_cond_img), dim = 1) + + # time conditioning + time_tokens = self.time_mlp(time) + # conditional dropout + cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device) cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1') diff --git a/setup.py b/setup.py index 86c0053..e0844f0 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.16', + version = '0.0.17', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',