mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 00:14:27 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
960a79857b | ||
|
|
7214df472d | ||
|
|
00ae50999b |
@@ -218,13 +218,12 @@ unet1 = Unet(
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off)
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unet and clip
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
@@ -349,8 +348,7 @@ unet2 = Unet(
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
lowres_cond = True
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
@@ -414,6 +412,7 @@ Offer training wrappers
|
||||
- [ ] make unet more configurable
|
||||
- [ ] figure out some factory methods to make cascading unet instantiations less error-prone
|
||||
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [ ] become an expert with unets, port learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
|
||||
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
from kornia.filters.gaussian import GaussianBlur2d
|
||||
from kornia.filters import gaussian_blur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
|
||||
@@ -811,15 +812,22 @@ class Unet(nn.Module):
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_cond_upsample_mode = 'bilinear',
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
|
||||
self._locals = locals()
|
||||
del self._locals['self']
|
||||
del self._locals['__class__']
|
||||
|
||||
# for eventual cascading diffusion
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
||||
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
|
||||
self.lowres_blur_kernel_size = blur_kernel_size
|
||||
self.lowres_blur_sigma = blur_sigma
|
||||
|
||||
# determine dimensions
|
||||
|
||||
@@ -893,6 +901,15 @@ class Unet(nn.Module):
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def force_lowres_cond(self, lowres_cond):
|
||||
if lowres_cond == self.lowres_cond:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
|
||||
return self.__class__(**updated_kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
*args,
|
||||
@@ -915,7 +932,9 @@ class Unet(nn.Module):
|
||||
image_embed,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
cond_drop_prob = 0.
|
||||
cond_drop_prob = 0.,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
batch_size, device = x.shape[0], x.device
|
||||
|
||||
@@ -926,7 +945,9 @@ class Unet(nn.Module):
|
||||
if exists(lowres_cond_img):
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
|
||||
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
|
||||
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
@@ -1014,7 +1035,17 @@ class Decoder(nn.Module):
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
|
||||
self.unets = cast_tuple(unet)
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
for ind, one_unet in enumerate(cast_tuple(unet)):
|
||||
is_first = ind == 0
|
||||
one_unet = one_unet.force_lowres_cond(not is_first)
|
||||
self.unets.append(one_unet)
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (clip.image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user