mirror of https://github.com/lucidrains/DALLE2-pytorch.git
prepare for cascading diffusion in unet, save the full progressive upsampling architecture to be built next week
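This commit wires low resolution image conditioning into the Unet so that, per the cascaded diffusion paper (https://cascaded-diffusion.github.io/), a base model generates at low resolution and successive unets super-resolve, each conditioned on the previous stage's output. A minimal usage sketch under the new flags; dimensions are illustrative, the import path may vary by version, and constructor arguments beyond the three added here may be required by the full class:

import torch
from dalle2_pytorch.dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    lowres_cond = True,                      # new: expect a low resolution conditioning image
    lowres_cond_upsample_mode = 'bilinear',  # new: how to upsample it to the target resolution
    blur_sigma = 0.1                         # new: width of the train-time gaussian blur
)

x = torch.randn(4, 3, 256, 256)              # noised input at the current (higher) resolution
time = torch.randint(0, 1000, (4,))          # diffusion timesteps
image_embed = torch.randn(4, 512)            # CLIP image embedding (illustrative dimension)
lowres = torch.randn(4, 3, 64, 64)           # previous stage's output

pred = unet(x, time, image_embed = image_embed, lowres_cond_img = lowres)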
@@ -11,7 +11,7 @@ from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 
-from kornia.filters import filter2d
+from kornia.filters.gaussian import GaussianBlur2d
 
 from dalle2_pytorch.tokenizer import tokenizer
 
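The import swap trades the hand-rolled Blur module (deleted in the next hunk) for kornia's GaussianBlur2d, whose kernel width is tunable through a per-axis sigma. A quick sketch of the new dependency in isolation, using the (3, 3) kernel and 0.1 sigma defaults introduced later in this commit:

import torch
from kornia.filters.gaussian import GaussianBlur2d

blur = GaussianBlur2d((3, 3), (0.1, 0.1))  # kernel size, per-axis sigma
images = torch.randn(2, 3, 64, 64)         # (batch, channels, height, width)
blurred = blur(images)                     # gaussian blurred, same shape as input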
@@ -625,17 +625,6 @@ def Upsample(dim):
 def Downsample(dim):
     return nn.Conv2d(dim, dim, 4, 2, 1)
 
-class Blur(nn.Module):
-    def __init__(self):
-        super().__init__()
-        filt = torch.Tensor([1, 2, 1])
-        self.register_buffer('filt', filt)
-
-    def forward(self, x):
-        filt = self.filt
-        filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
-        return filter2d(x, filt, normalized = True)
-
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
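For reference, the removed Blur was not a true gaussian: it built a fixed 3x3 binomial kernel as the outer product of [1, 2, 1] with itself and had filter2d normalize it. A sketch of what it computed:

import torch
from einops import rearrange

filt = torch.tensor([1., 2., 1.])
kernel = rearrange(filt, 'j -> 1 j') * rearrange(filt, 'i -> i 1')
# tensor([[1., 2., 1.],
#         [2., 4., 2.],
#         [1., 2., 1.]])
# filter2d(x, kernel, normalized = True) divides by the kernel sum (16); the
# replacement GaussianBlur2d computes a comparable kernel but exposes its
# width through blur_sigma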
@@ -769,11 +758,25 @@ class Unet(nn.Module):
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
+        lowres_cond_upsample_mode = 'bilinear',
+        blur_sigma = 0.1
     ):
         super().__init__()
+
+        # for eventual cascading diffusion
+
+        self.lowres_cond = lowres_cond
+        self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
+        self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
+
+        # determine dimensions
+
         self.channels = channels
 
-        dims = [channels, *map(lambda m: dim * m, dim_mults)]
+        init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
+
+        dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
         # time, image embeddings, and optional text encoding
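Concretely, enabling lowres_cond only changes the channel width seen by the first convolution, since the blurred, upsampled conditioning image is concatenated onto the input. A worked example of the dimension bookkeeping, assuming dim = 128 and the defaults above:

channels, dim, dim_mults, lowres_cond = 3, 128, (1, 2, 4, 8), True

init_channels = channels if not lowres_cond else channels * 2  # 6
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]     # [6, 128, 256, 512, 1024]
in_out = list(zip(dims[:-1], dims[1:]))                        # [(6, 128), (128, 256), (256, 512), (512, 1024)]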
@@ -856,12 +859,30 @@ class Unet(nn.Module):
         time,
         *,
         image_embed,
+        lowres_cond_img = None,
         text_encodings = None,
         cond_drop_prob = 0.
     ):
         batch_size, device = x.shape[0], x.device
 
+        # add low resolution conditioning, if present
+
+        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
+
+        if exists(lowres_cond_img):
+            if self.training:
+                # when training, blur the low resolution conditional image
+                lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
+
+            lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode)
+            x = torch.cat((x, lowres_cond_img), dim = 1)
+
+        # time conditioning
+
         time_tokens = self.time_mlp(time)
 
+        # conditional dropout
+
         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
         cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
 
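Unrolled, the new conditioning path does three things: blur the low resolution image at train time only (conditioning augmentation, so the super-resolution stage learns to tolerate imperfect samples from the stage below), upsample it to the current resolution, and concatenate it onto the input channels. A standalone sketch with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 256, 256)              # noised input at the current stage
lowres_cond_img = torch.randn(4, 3, 64, 64)  # output of the previous stage

# (train time only) the gaussian blur would be applied here, see self.lowres_cond_blur

lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = 'bilinear')
x = torch.cat((x, lowres_cond_img), dim = 1) # (4, 6, 256, 256), matching init_channels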