From 00ae50999bafaad92b5588eb32894d3bcac6af07 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 18 Apr 2022 12:04:31 -0700 Subject: [PATCH] make kernel size and sigma for gaussian blur for cascading DDPM overridable at forward. also make sure unets are wrapped in a modulelist so that at sample time, blurring does not happen --- README.md | 2 +- dalle2_pytorch/dalle2_pytorch.py | 17 ++++++++++++----- setup.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 91dcf8f..95ae168 100644 --- a/README.md +++ b/README.md @@ -224,7 +224,7 @@ unet2 = Unet( dim_mults = (1, 2, 4, 8, 16) ).cuda() -# decoder, which contains the unet and clip +# decoder, which contains the unet(s) and clip decoder = Decoder( clip = clip, diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index ea69b5b..4f84123 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1,6 +1,7 @@ import math from tqdm import tqdm from inspect import isfunction +from functools import partial import torch import torch.nn.functional as F @@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange from einops_exts import rearrange_many, repeat_many, check_shape from einops_exts.torch import EinopsToAndFrom -from kornia.filters.gaussian import GaussianBlur2d +from kornia.filters import gaussian_blur2d from dalle2_pytorch.tokenizer import tokenizer @@ -811,6 +812,7 @@ class Unet(nn.Module): lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ lowres_cond_upsample_mode = 'bilinear', blur_sigma = 0.1, + blur_kernel_size = 3, attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) ): super().__init__() @@ -819,7 +821,8 @@ class Unet(nn.Module): self.lowres_cond = lowres_cond self.lowres_cond_upsample_mode = lowres_cond_upsample_mode - self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma)) + self.lowres_blur_kernel_size = blur_kernel_size + self.lowres_blur_sigma = blur_sigma # determine dimensions @@ -915,7 +918,9 @@ class Unet(nn.Module): image_embed, lowres_cond_img = None, text_encodings = None, - cond_drop_prob = 0. + cond_drop_prob = 0., + blur_sigma = None, + blur_kernel_size = None ): batch_size, device = x.shape[0], x.device @@ -926,7 +931,9 @@ class Unet(nn.Module): if exists(lowres_cond_img): if self.training: # when training, blur the low resolution conditional image - lowres_cond_img = self.lowres_cond_blur(lowres_cond_img) + blur_sigma = default(blur_sigma, self.lowres_blur_sigma) + blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size) + lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2)) lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode) x = torch.cat((x, lowres_cond_img), dim = 1) @@ -1014,7 +1021,7 @@ class Decoder(nn.Module): self.clip_image_size = clip.image_size self.channels = clip.image_channels - self.unets = cast_tuple(unet) + self.unets = nn.ModuleList(unet) image_sizes = default(image_sizes, (clip.image_size,)) image_sizes = tuple(sorted(set(image_sizes))) diff --git a/setup.py b/setup.py index fcf89e3..8e44f4d 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.20', + version = '0.0.21', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',