From 4d346e98d995a34e07628a6b1071eb36a30756c4 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 22 May 2022 20:36:20 -0700 Subject: [PATCH] allow for config driven creation of clip-less diffusion prior --- README.md | 2 +- dalle2_pytorch/train_configs.py | 43 +++++++++++++++++++++++++++++++-- setup.py | 2 +- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 47b24f4..b621915 100644 --- a/README.md +++ b/README.md @@ -1084,6 +1084,7 @@ This library would not have gotten to this working state without the help of - [x] use pydantic for config drive training - [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number) - [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes +- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs - [ ] train on a toy task, offer in colab @@ -1097,7 +1098,6 @@ This library would not have gotten to this working state without the help of - [ ] decoder needs one day worth of refactor for tech debt - [ ] allow for unet to be able to condition non-cross attention style as well - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89 -- [ ] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs ## Citations diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index 2c8765a..695f5ee 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -3,7 +3,7 @@ from torchvision import transforms as T from pydantic import BaseModel, validator, root_validator from typing import List, Iterable, Optional, Union, Tuple, Dict, Any -from dalle2_pytorch.dalle2_pytorch import Unet, Decoder +from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork # helper functions @@ -18,6 +18,43 @@ def ListOrTuple(inner_type): # pydantic classes +class DiffusionPriorNetworkConfig(BaseModel): + dim: int + depth: int + num_timesteps: int = None + num_time_embeds: int = 1 + num_image_embeds: int = 1 + num_text_embeds: int = 1 + dim_head: int = 64 + heads: int = 8 + ff_mult: int = 4 + norm_out: bool = True + attn_dropout: float = 0. + ff_dropout: float = 0. + final_proj: bool = True + normformer: bool = False + rotary_emb: bool = True + +class DiffusionPriorConfig(BaseModel): + # only clip-less diffusion prior config for now + net: DiffusionPriorNetworkConfig + image_embed_dim: int + image_size: int + image_channels: int = 3 + timesteps: int = 1000 + cond_drop_prob: float = 0. + loss_type: str = 'l2' + predict_x_start: bool = True + beta_schedule: str = 'cosine' + + def create(self): + kwargs = self.dict() + diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net')) + return DiffusionPrior(net = diffusion_prior_network, **kwargs) + + class Config: + extra = "allow" + class UnetConfig(BaseModel): dim: int dim_mults: ListOrTuple(int) @@ -31,7 +68,7 @@ class UnetConfig(BaseModel): extra = "allow" class DecoderConfig(BaseModel): - unets: Union[List[UnetConfig], Tuple[UnetConfig]] + unets: ListOrTuple(UnetConfig) image_size: int = None image_sizes: ListOrTuple(int) = None channels: int = 3 @@ -39,6 +76,8 @@ class DecoderConfig(BaseModel): loss_type: str = 'l2' beta_schedule: str = 'cosine' learned_variance: bool = True + image_cond_drop_prob: float = 0.1 + text_cond_drop_prob: float = 0.5 def create(self): decoder_kwargs = self.dict() diff --git a/setup.py b/setup.py index 0b14740..60ca0e9 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.4.8', + version = '0.4.9', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',