mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
allow for config driven creation of clip-less diffusion prior
@@ -1084,6 +1084,7 @@ This library would not have gotten to this working state without the help of
 - [x] use pydantic for config driven training
 - [x] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored as well (as well as the step number)
 - [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
+- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1097,7 +1098,6 @@ This library would not have gotten to this working state without the help of
 - [ ] decoder needs one day's worth of refactoring for tech debt
 - [ ] allow for unet to be able to condition non-cross attention style as well
 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
-- [ ] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
 
 ## Citations
 
@@ -3,7 +3,7 @@ from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
-from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
+from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
 
 # helper functions
 
@@ -18,6 +18,43 @@ def ListOrTuple(inner_type):
 
 # pydantic classes
 
+class DiffusionPriorNetworkConfig(BaseModel):
+    dim: int
+    depth: int
+    num_timesteps: int = None
+    num_time_embeds: int = 1
+    num_image_embeds: int = 1
+    num_text_embeds: int = 1
+    dim_head: int = 64
+    heads: int = 8
+    ff_mult: int = 4
+    norm_out: bool = True
+    attn_dropout: float = 0.
+    ff_dropout: float = 0.
+    final_proj: bool = True
+    normformer: bool = False
+    rotary_emb: bool = True
+
+class DiffusionPriorConfig(BaseModel):
+    # only clip-less diffusion prior config for now
+    net: DiffusionPriorNetworkConfig
+    image_embed_dim: int
+    image_size: int
+    image_channels: int = 3
+    timesteps: int = 1000
+    cond_drop_prob: float = 0.
+    loss_type: str = 'l2'
+    predict_x_start: bool = True
+    beta_schedule: str = 'cosine'
+
+    def create(self):
+        kwargs = self.dict()
+        diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
+        return DiffusionPrior(net = diffusion_prior_network, **kwargs)
+
+    class Config:
+        extra = "allow"
+
 class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple(int)
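Together, these two classes make the clip-less prior constructible from plain data. A minimal usage sketch, assuming the file above is `dalle2_pytorch/train_configs.py` (the module path and all dimension values below are illustrative assumptions, not part of this commit):

```python
from dalle2_pytorch.train_configs import DiffusionPriorConfig  # assumed module path

# pydantic coerces the nested dict into a DiffusionPriorNetworkConfig, so a
# JSON-style payload is enough to describe the whole prior
prior_config = DiffusionPriorConfig(
    net = dict(dim = 512, depth = 6),  # illustrative network hyperparameters
    image_embed_dim = 512,             # dimensionality of the image embeddings the prior predicts
    image_size = 256,
    timesteps = 1000,
)

# create() pops the network kwargs, builds the DiffusionPriorNetwork, and wraps
# it in a DiffusionPrior; no CLIP is attached, per the "clip-less" comment above
diffusion_prior = prior_config.create()
```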
@@ -31,7 +68,7 @@ class UnetConfig(BaseModel):
         extra = "allow"
 
 class DecoderConfig(BaseModel):
-    unets: Union[List[UnetConfig], Tuple[UnetConfig]]
+    unets: ListOrTuple(UnetConfig)
     image_size: int = None
     image_sizes: ListOrTuple(int) = None
     channels: int = 3
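`ListOrTuple` itself sits outside the diff window (only its `def` line appears in a hunk header above). Judging from the `Union[List[UnetConfig], Tuple[UnetConfig]]` annotation it replaces, a plausible reconstruction, offered as an assumption:

```python
from typing import List, Tuple, Union

def ListOrTuple(inner_type):
    # builds the Union annotation the removed line spelled out by hand;
    # Tuple[inner_type, ...] admits tuples of any length, whereas the removed
    # Tuple[UnetConfig] annotation technically only admits 1-tuples
    return Union[List[inner_type], Tuple[inner_type, ...]]
```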
@@ -39,6 +76,8 @@ class DecoderConfig(BaseModel):
     loss_type: str = 'l2'
     beta_schedule: str = 'cosine'
     learned_variance: bool = True
+    image_cond_drop_prob: float = 0.1
+    text_cond_drop_prob: float = 0.5
 
     def create(self):
         decoder_kwargs = self.dict()
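The hunk cuts `DecoderConfig.create()` off after its first statement. By analogy with `DiffusionPriorConfig.create()` above, a hedged sketch of how the body plausibly continues; the class name, the list comprehension, and the exact `Decoder` call are assumptions, not part of this commit:

```python
from pydantic import BaseModel
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder

class DecoderConfigSketch(BaseModel):  # hypothetical stand-in for DecoderConfig
    class Config:
        extra = "allow"

    def create(self):
        decoder_kwargs = self.dict()
        # 'unets' holds one dict per UnetConfig; turn each into a real Unet module
        unets = [Unet(**unet_kwargs) for unet_kwargs in decoder_kwargs.pop('unets')]
        return Decoder(unet = unets, **decoder_kwargs)
```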