Compare commits

...

2 Commits

Author     SHA1        Message                                                        Date
Phil Wang  1992d25cad  project management                                             2022-05-04 11:18:54 -07:00
Phil Wang  5b619c2fd5  make sure some hyperparameters for unet block is configurable  2022-05-04 11:18:32 -07:00
3 changed files with 6 additions and 4 deletions


@@ -821,6 +821,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
 - [x] bring in tools to train vqgan-vae
 - [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
+- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
@@ -835,7 +836,6 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
-- [ ] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
 ## Citations


@@ -1196,6 +1196,8 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         block_type = 'resnet',
+        block_resnet_groups = 8,
+        block_convnext_mult = 2,
         **kwargs
     ):
         super().__init__()
@@ -1274,9 +1276,9 @@ class Unet(nn.Module):
         # whether to use resnet or the (improved?) convnext blocks

         if block_type == 'resnet':
-            block_klass = ResnetBlock
+            block_klass = partial(ResnetBlock, groups = block_resnet_groups)
         elif block_type == 'convnext':
-            block_klass = ConvNextBlock
+            block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
         else:
             raise ValueError(f'unimplemented block type {block_type}')
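
The hunk above binds each block's new hyperparameter into `block_klass` with `functools.partial`, so the rest of the Unet can call `block_klass(dim)` uniformly regardless of block type. A minimal runnable sketch of that pattern, using simplified stand-in classes rather than the real dalle2-pytorch modules:

```python
from functools import partial

# simplified stand-ins for the actual ResnetBlock / ConvNextBlock modules
class ResnetBlock:
    def __init__(self, dim, groups = 8):
        self.dim = dim
        self.groups = groups

class ConvNextBlock:
    def __init__(self, dim, mult = 2):
        self.dim = dim
        self.mult = mult

def make_block_klass(block_type, block_resnet_groups = 8, block_convnext_mult = 2):
    # bind the block-specific hyperparameter once, up front;
    # callers then only need to supply the feature dimension
    if block_type == 'resnet':
        return partial(ResnetBlock, groups = block_resnet_groups)
    elif block_type == 'convnext':
        return partial(ConvNextBlock, mult = block_convnext_mult)
    raise ValueError(f'unimplemented block type {block_type}')

block_klass = make_block_klass('resnet', block_resnet_groups = 4)
block = block_klass(64)  # groups was already bound, only dim is passed here
```

This is why the diff exposes `block_resnet_groups` and `block_convnext_mult` as constructor kwargs: the choice of hyperparameter is made once at configuration time, not at every block instantiation site.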


@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.95',
+  version = '0.0.96',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',