make the gaussian blur kernel size and sigma for cascading DDPM overridable at forward. also make sure unets are wrapped in a modulelist so that at sample time, blurring does not happen

Phil Wang
2022-04-18 12:04:31 -07:00
parent 6cddefad26
commit 00ae50999b
3 changed files with 14 additions and 7 deletions
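
What the override looks like at the call site, as a hedged sketch: the first two positional arguments and every variable name below are placeholders rather than code from this repository; only the blur_sigma and blur_kernel_size keyword arguments come from the diff. Left as None they fall back to the values fixed at construction (0.1 and 3 by default), and the blur only fires while the unet is in training mode.

    # hypothetical call site; `unet`, `noised_image`, `times`, `image_embed`
    # and `lowres_image` are stand-ins, not names from the repository
    pred = unet(
        noised_image,
        times,
        image_embed = image_embed,
        lowres_cond_img = lowres_image,
        blur_sigma = 0.6,        # overrides the constructor's blur_sigma (0.1)
        blur_kernel_size = 5     # overrides the constructor's blur_kernel_size (3)
    )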

README.md

@@ -224,7 +224,7 @@ unet2 = Unet(
     dim_mults = (1, 2, 4, 8, 16)
 ).cuda()
 
-# decoder, which contains the unet and clip
+# decoder, which contains the unet(s) and clip
 
 decoder = Decoder(
     clip = clip,

dalle2_pytorch/dalle2_pytorch.py

@@ -1,6 +1,7 @@
 import math
 from tqdm import tqdm
 from inspect import isfunction
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 
-from kornia.filters.gaussian import GaussianBlur2d
+from kornia.filters import gaussian_blur2d
 
 from dalle2_pytorch.tokenizer import tokenizer
@@ -811,6 +812,7 @@ class Unet(nn.Module):
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_cond_upsample_mode = 'bilinear',
         blur_sigma = 0.1,
+        blur_kernel_size = 3,
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
     ):
         super().__init__()
@@ -819,7 +821,8 @@
         self.lowres_cond = lowres_cond
         self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
-        self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
+        self.lowres_blur_kernel_size = blur_kernel_size
+        self.lowres_blur_sigma = blur_sigma
 
         # determine dimensions
@@ -915,7 +918,9 @@
         image_embed,
         lowres_cond_img = None,
         text_encodings = None,
-        cond_drop_prob = 0.
+        cond_drop_prob = 0.,
+        blur_sigma = None,
+        blur_kernel_size = None
     ):
         batch_size, device = x.shape[0], x.device
@@ -926,7 +931,9 @@
         if exists(lowres_cond_img):
             if self.training:
                 # when training, blur the low resolution conditional image
-                lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
+                blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
+                blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
+                lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
 
             lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
             x = torch.cat((x, lowres_cond_img), dim = 1)
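
The module-to-functional swap above is what makes the override possible: kornia's GaussianBlur2d bakes kernel size and sigma in at construction (note the hardcoded (3, 3) it replaces), while gaussian_blur2d takes both per call. A minimal standalone sketch, with a random tensor standing in for the low resolution conditioner:

    import torch
    from kornia.filters import gaussian_blur2d

    img = torch.randn(1, 3, 64, 64)  # stand-in for lowres_cond_img
    # kernel size and sigma are per-call arguments, so each forward can differ
    blurred = gaussian_blur2d(img, (3, 3), (0.1, 0.1))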
@@ -1014,7 +1021,7 @@ class Decoder(nn.Module):
         self.clip_image_size = clip.image_size
         self.channels = clip.image_channels
 
-        self.unets = cast_tuple(unet)
+        self.unets = nn.ModuleList(unet)
 
         image_sizes = default(image_sizes, (clip.image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))
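
Why nn.ModuleList matters for the "no blurring at sample time" part of this commit: modules held in a plain tuple are not registered as submodules, so decoder.eval() never reached the unets, self.training stayed True inside their forward, and the blur branch ran even when sampling. A toy sketch of the mechanism (ToyUnet and Holder are illustrations, not repository code):

    import torch.nn as nn

    class ToyUnet(nn.Module):
        def forward(self, x):
            # stand-in for the real forward: blur only while training
            return x * 0.5 if self.training else x

    class Holder(nn.Module):
        def __init__(self):
            super().__init__()
            # ModuleList registers the unets, so train()/eval() propagate to them
            self.unets = nn.ModuleList([ToyUnet(), ToyUnet()])

    holder = Holder()
    holder.eval()
    assert all(not u.training for u in holder.unets)  # blur is skipped when sampling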

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.20',
+  version = '0.0.21',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',