make kernel size and sigma for gaussian blur for cascading DDPM overridable at forward. also make sure unets are wrapped in a modulelist so that at sample time, blurring does not happen

This commit is contained in:
Phil Wang
2022-04-18 12:04:31 -07:00
parent 6cddefad26
commit 00ae50999b
3 changed files with 14 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
import torch
import torch.nn.functional as F
@@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters.gaussian import GaussianBlur2d
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -811,6 +812,7 @@ class Unet(nn.Module):
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
blur_kernel_size = 3,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
):
super().__init__()
@@ -819,7 +821,8 @@ class Unet(nn.Module):
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
@@ -915,7 +918,9 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.
cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
batch_size, device = x.shape[0], x.device
@@ -926,7 +931,9 @@ class Unet(nn.Module):
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -1014,7 +1021,7 @@ class Decoder(nn.Module):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.unets = cast_tuple(unet)
self.unets = nn.ModuleList(unet)
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))