Compare commits

..

3 Commits

3 changed files with 29 additions and 7 deletions

View File

@@ -218,13 +218,12 @@ unet1 = Unet(
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off)
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unet and clip
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
@@ -349,8 +348,7 @@ unet2 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
lowres_cond = True
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
@@ -414,6 +412,7 @@ Offer training wrappers
- [ ] make unet more configurable
- [ ] figure out some factory methods to make cascading unet instantiations less error-prone
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] become an expert with unets, port learnings over to https://github.com/lucidrains/x-unet
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)

View File

@@ -12,7 +12,6 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters.gaussian import GaussianBlur2d
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -817,6 +816,11 @@ class Unet(nn.Module):
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
self._locals = locals()
del self._locals['self']
del self._locals['__class__']
# for eventual cascading diffusion
@@ -897,6 +901,15 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1)
)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def force_lowres_cond(self, lowres_cond):
if lowres_cond == self.lowres_cond:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
self,
*args,
@@ -1022,7 +1035,17 @@ class Decoder(nn.Module):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.unets = nn.ModuleList(unet)
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
self.unets = nn.ModuleList([])
for ind, one_unet in enumerate(cast_tuple(unet)):
is_first = ind == 0
one_unet = one_unet.force_lowres_cond(not is_first)
self.unets.append(one_unet)
# unet image sizes
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.21',
version = '0.0.22',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',