Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
allow for config driven creation of clip-less diffusion prior
@@ -3,7 +3,7 @@ from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
 from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
 
-from dalle2_pytorch.dalle2_pytorch import Unet, Decoder
+from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
 
 # helper functions
 
@@ -18,6 +18,43 @@ def ListOrTuple(inner_type):
 
 # pydantic classes
 
+class DiffusionPriorNetworkConfig(BaseModel):
+    dim: int
+    depth: int
+    num_timesteps: int = None
+    num_time_embeds: int = 1
+    num_image_embeds: int = 1
+    num_text_embeds: int = 1
+    dim_head: int = 64
+    heads: int = 8
+    ff_mult: int = 4
+    norm_out: bool = True
+    attn_dropout: float = 0.
+    ff_dropout: float = 0.
+    final_proj: bool = True
+    normformer: bool = False
+    rotary_emb: bool = True
+
+class DiffusionPriorConfig(BaseModel):
+    # only clip-less diffusion prior config for now
+    net: DiffusionPriorNetworkConfig
+    image_embed_dim: int
+    image_size: int
+    image_channels: int = 3
+    timesteps: int = 1000
+    cond_drop_prob: float = 0.
+    loss_type: str = 'l2'
+    predict_x_start: bool = True
+    beta_schedule: str = 'cosine'
+
+    def create(self):
+        kwargs = self.dict()
+        diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
+        return DiffusionPrior(net = diffusion_prior_network, **kwargs)
+
+    class Config:
+        extra = "allow"
+
 class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple(int)
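The two classes added above are what make the clip-less prior config-driven: a plain dict (for example, parsed from JSON) is validated by pydantic and then turned into live modules by create(). A minimal usage sketch, assuming the config classes are importable as dalle2_pytorch.train_configs (the file path is not shown in this diff) and using hypothetical hyperparameter values:

    # hypothetical config values; keys mirror the pydantic fields above,
    # and any field left unspecified falls back to its declared default
    from dalle2_pytorch.train_configs import DiffusionPriorConfig

    prior_config = DiffusionPriorConfig(
        net = dict(dim = 512, depth = 6, dim_head = 64, heads = 8),
        image_embed_dim = 512,
        image_size = 256,
        timesteps = 1000,
        cond_drop_prob = 0.1,
    )

    # create() pops the nested 'net' dict into a DiffusionPriorNetwork,
    # then passes the remaining fields to a clip-less DiffusionPrior
    diffusion_prior = prior_config.create()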
@@ -31,7 +68,7 @@ class UnetConfig(BaseModel):
         extra = "allow"
 
 class DecoderConfig(BaseModel):
-    unets: Union[List[UnetConfig], Tuple[UnetConfig]]
+    unets: ListOrTuple(UnetConfig)
     image_size: int = None
     image_sizes: ListOrTuple(int) = None
     channels: int = 3
@@ -39,6 +76,8 @@ class DecoderConfig(BaseModel):
     loss_type: str = 'l2'
     beta_schedule: str = 'cosine'
     learned_variance: bool = True
+    image_cond_drop_prob: float = 0.1
+    text_cond_drop_prob: float = 0.5
 
     def create(self):
         decoder_kwargs = self.dict()
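The decoder path follows the same pattern; the hunk above only adds the two conditioning drop-probability fields, and its create() body is truncated here. A sketch of loading it from a file, assuming the same dalle2_pytorch.train_configs module path and a hypothetical decoder_config.json whose keys mirror the DecoderConfig / UnetConfig fields (several UnetConfig fields lie outside this diff's context, so a real config may need more keys):

    import json

    from dalle2_pytorch.train_configs import DecoderConfig

    # hypothetical file, e.g.
    # {"unets": [{"dim": 128, "dim_mults": [1, 2, 4, 8]}],
    #  "image_sizes": [256],
    #  "image_cond_drop_prob": 0.1,
    #  "text_cond_drop_prob": 0.5}
    with open('decoder_config.json') as f:
        decoder_config = DecoderConfig(**json.load(f))

    # builds the Unet(s) and Decoder from the validated fields
    decoder = decoder_config.create()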