mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
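Make the kernel size and sigma of the Gaussian blur used for cascading DDPM overridable at forward time. Also make sure the unets are wrapped in an nn.ModuleList, so that they are registered as submodules of the decoder and follow its train/eval mode; since the blur is only applied while training, blurring then does not happen at sample time.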
This commit is contained in:
@@ -224,7 +224,7 @@ unet2 = Unet(
|
|||||||
dim_mults = (1, 2, 4, 8, 16)
|
dim_mults = (1, 2, 4, 8, 16)
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
# decoder, which contains the unet and clip
|
# decoder, which contains the unet(s) and clip
|
||||||
|
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
clip = clip,
|
clip = clip,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange
|
|||||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||||
from einops_exts.torch import EinopsToAndFrom
|
from einops_exts.torch import EinopsToAndFrom
|
||||||
|
|
||||||
from kornia.filters.gaussian import GaussianBlur2d
|
from kornia.filters import gaussian_blur2d
|
||||||
|
|
||||||
from dalle2_pytorch.tokenizer import tokenizer
|
from dalle2_pytorch.tokenizer import tokenizer
|
||||||
|
|
||||||
@@ -811,6 +812,7 @@ class Unet(nn.Module):
|
|||||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||||
lowres_cond_upsample_mode = 'bilinear',
|
lowres_cond_upsample_mode = 'bilinear',
|
||||||
blur_sigma = 0.1,
|
blur_sigma = 0.1,
|
||||||
|
blur_kernel_size = 3,
|
||||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -819,7 +821,8 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
self.lowres_cond = lowres_cond
|
self.lowres_cond = lowres_cond
|
||||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
||||||
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
|
self.lowres_blur_kernel_size = blur_kernel_size
|
||||||
|
self.lowres_blur_sigma = blur_sigma
|
||||||
|
|
||||||
# determine dimensions
|
# determine dimensions
|
||||||
|
|
||||||
@@ -915,7 +918,9 @@ class Unet(nn.Module):
|
|||||||
image_embed,
|
image_embed,
|
||||||
lowres_cond_img = None,
|
lowres_cond_img = None,
|
||||||
text_encodings = None,
|
text_encodings = None,
|
||||||
cond_drop_prob = 0.
|
cond_drop_prob = 0.,
|
||||||
|
blur_sigma = None,
|
||||||
|
blur_kernel_size = None
|
||||||
):
|
):
|
||||||
batch_size, device = x.shape[0], x.device
|
batch_size, device = x.shape[0], x.device
|
||||||
|
|
||||||
@@ -926,7 +931,9 @@ class Unet(nn.Module):
|
|||||||
if exists(lowres_cond_img):
|
if exists(lowres_cond_img):
|
||||||
if self.training:
|
if self.training:
|
||||||
# when training, blur the low resolution conditional image
|
# when training, blur the low resolution conditional image
|
||||||
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
|
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
|
||||||
|
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
|
||||||
|
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||||
|
|
||||||
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||||
@@ -1014,7 +1021,7 @@ class Decoder(nn.Module):
|
|||||||
self.clip_image_size = clip.image_size
|
self.clip_image_size = clip.image_size
|
||||||
self.channels = clip.image_channels
|
self.channels = clip.image_channels
|
||||||
|
|
||||||
self.unets = cast_tuple(unet)
|
self.unets = nn.ModuleList(unet)
|
||||||
image_sizes = default(image_sizes, (clip.image_size,))
|
image_sizes = default(image_sizes, (clip.image_size,))
|
||||||
image_sizes = tuple(sorted(set(image_sizes)))
|
image_sizes = tuple(sorted(set(image_sizes)))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user