mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
wrap up cross embed layer feature
This commit is contained in:
@@ -1003,6 +1003,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
@@ -1015,7 +1016,6 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] cross embed layers for downsampling, as an option
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -41,9 +41,6 @@ def exists(val):
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def is_odd(n):
|
||||
return (n % 2) == 1
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
@@ -1235,12 +1232,13 @@ class CrossEmbedLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
kernel_sizes,
|
||||
dim_out = None,
|
||||
stride = 2
|
||||
):
|
||||
super().__init__()
|
||||
assert all([*map(is_odd, kernel_sizes)])
|
||||
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
|
||||
dim_out = default(dim_out, dim_in)
|
||||
|
||||
kernel_sizes = sorted(kernel_sizes)
|
||||
num_scales = len(kernel_sizes)
|
||||
@@ -1282,6 +1280,8 @@ class Unet(nn.Module):
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1302,7 +1302,7 @@ class Unet(nn.Module):
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
|
||||
self.init_conv = CrossEmbedLayer(init_channels, init_dim, init_cross_embed_kernel_sizes, stride = 1)
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
@@ -1362,6 +1362,12 @@ class Unet(nn.Module):
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
# downsample klass
|
||||
|
||||
downsample_klass = Downsample
|
||||
if cross_embed_downsample:
|
||||
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
@@ -1377,7 +1383,7 @@ class Unet(nn.Module):
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
Reference in New Issue
Block a user