mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
add more support for configuring prior (#113)
This commit is contained in:
70
configs/train_prior_config.example.json
Normal file
70
configs/train_prior_config.example.json
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
{
|
||||||
|
"prior": {
|
||||||
|
"clip": {
|
||||||
|
"make": "x-clip",
|
||||||
|
"model": "ViT-L/14",
|
||||||
|
"base_model_kwargs": {
|
||||||
|
"dim_text": 768,
|
||||||
|
"dim_image": 768,
|
||||||
|
"dim_latent": 768
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"net": {
|
||||||
|
"dim": 768,
|
||||||
|
"depth": 12,
|
||||||
|
"num_timesteps": 1000,
|
||||||
|
"num_time_embeds": 1,
|
||||||
|
"num_image_embeds": 1,
|
||||||
|
"num_text_embeds": 1,
|
||||||
|
"dim_head": 64,
|
||||||
|
"heads": 12,
|
||||||
|
"ff_mult": 4,
|
||||||
|
"norm_out": true,
|
||||||
|
"attn_dropout": 0.0,
|
||||||
|
"ff_dropout": 0.0,
|
||||||
|
"final_proj": true,
|
||||||
|
"normformer": true,
|
||||||
|
"rotary_emb": true
|
||||||
|
},
|
||||||
|
"image_embed_dim": 768,
|
||||||
|
"image_size": 224,
|
||||||
|
"image_channels": 3,
|
||||||
|
"timesteps": 1000,
|
||||||
|
"cond_drop_prob": 0.1,
|
||||||
|
"loss_type": "l2",
|
||||||
|
"predict_x_start": true,
|
||||||
|
"beta_schedule": "cosine",
|
||||||
|
"condition_on_text_encodings": true
|
||||||
|
},
|
||||||
|
"data": {
|
||||||
|
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
|
||||||
|
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
|
||||||
|
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
|
||||||
|
"batch_size": 256,
|
||||||
|
"splits": {
|
||||||
|
"train": 0.9,
|
||||||
|
"val": 1e-7,
|
||||||
|
"test": 0.0999999
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"train": {
|
||||||
|
"epochs": 1,
|
||||||
|
"lr": 1.1e-4,
|
||||||
|
"wd": 6.02e-2,
|
||||||
|
"max_grad_norm": 0.5,
|
||||||
|
"use_ema": true,
|
||||||
|
"amp": false,
|
||||||
|
"save_every": 10000
|
||||||
|
},
|
||||||
|
"load": {
|
||||||
|
"source": null,
|
||||||
|
"resume": false
|
||||||
|
},
|
||||||
|
"tracker": {
|
||||||
|
"tracker_type": "wandb",
|
||||||
|
"data_path": "./prior_checkpoints",
|
||||||
|
"wandb_entity": "laion",
|
||||||
|
"wandb_project": "diffusion-prior",
|
||||||
|
"verbose": true
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,7 +3,18 @@ from torchvision import transforms as T
|
|||||||
from pydantic import BaseModel, validator, root_validator
|
from pydantic import BaseModel, validator, root_validator
|
||||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||||
|
|
||||||
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork
|
from x_clip import CLIP as XCLIP
|
||||||
|
from coca_pytorch import CoCa
|
||||||
|
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import (
|
||||||
|
CoCaAdapter,
|
||||||
|
OpenAIClipAdapter,
|
||||||
|
Unet,
|
||||||
|
Decoder,
|
||||||
|
DiffusionPrior,
|
||||||
|
DiffusionPriorNetwork,
|
||||||
|
XClipAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
@@ -16,7 +27,44 @@ def default(val, d):
|
|||||||
def ListOrTuple(inner_type):
|
def ListOrTuple(inner_type):
|
||||||
return Union[List[inner_type], Tuple[inner_type]]
|
return Union[List[inner_type], Tuple[inner_type]]
|
||||||
|
|
||||||
# pydantic classes
|
# general pydantic classes
|
||||||
|
|
||||||
|
class TrainSplitConfig(BaseModel):
    """Fractions of the data assigned to the train / val / test splits.

    The fractions are validated to sum to 1.0, within a small tolerance.
    """
    train: float = 0.75
    val: float = 0.15
    test: float = 0.1

    @root_validator
    def validate_all(cls, fields):
        """Check that the split fractions sum to 1.0.

        fields: mapping of field name to value, as handed over by pydantic
        Returns the same mapping unchanged.
        Raises ValueError when the sum is not (approximately) 1.0.
        """
        import math  # function-scope import so the module's import block is untouched

        actual_sum = sum(fields.values())

        # compare with a tolerance: an exact `!= 1.` check can reject splits that add up
        # to 1.0 on paper (e.g. 0.9 / 1e-7 / 0.0999999) because of float rounding error
        if not math.isclose(actual_sum, 1., abs_tol = 1e-9):
            raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')

        return fields
|
||||||
|
|
||||||
|
class TrackerConfig(BaseModel):
    """Tracker settings: which tracker to use and where files are saved locally."""
    # Decoder currently supports console and wandb
    tracker_type: str = 'console'

    # The path where files will be saved locally
    data_path: str = './models'

    # The default is None, so the annotation says Optional explicitly instead of
    # relying on pydantic inferring optionality from the default value
    init_config: Optional[Dict[str, Any]] = None

    # Only needs to be set if tracker_type is wandb
    wandb_entity: str = ''
    wandb_project: str = ''

    # Whether to print console logging for non-console trackers
    verbose: bool = False
|
||||||
|
|
||||||
|
# diffusion prior pydantic classes
|
||||||
|
|
||||||
|
class AdapterConfig(BaseModel):
    """Describes which CLIP adapter to build and how to build it."""
    # one of "openai", "x-clip" or "coca" (the values `create` recognises)
    make: str = "openai"

    # model name handed to OpenAIClipAdapter; only read when make is "openai"
    model: str = "ViT-L/14"

    # keyword arguments for the base model; only read when make is "x-clip" or "coca"
    base_model_kwargs: Optional[Dict[str, Any]] = None

    def create(self):
        """Instantiate the adapter selected by `make`.

        Returns an OpenAIClipAdapter, XClipAdapter or CoCaAdapter.
        Raises AttributeError when `make` is not a recognised adapter name.
        """
        if self.make == "openai":
            return OpenAIClipAdapter(self.model)

        # `base_model_kwargs` defaults to None and `**None` raises a TypeError,
        # so fall back to an empty mapping and let the base model report what it needs
        base_model_kwargs = self.base_model_kwargs or {}

        if self.make == "x-clip":
            return XClipAdapter(XCLIP(**base_model_kwargs))

        if self.make == "coca":
            return CoCaAdapter(CoCa(**base_model_kwargs))

        raise AttributeError("No adapter with that name is available.")
|
||||||
|
|
||||||
class DiffusionPriorNetworkConfig(BaseModel):
|
class DiffusionPriorNetworkConfig(BaseModel):
|
||||||
dim: int
|
dim: int
|
||||||
@@ -35,8 +83,12 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
|||||||
normformer: bool = False
|
normformer: bool = False
|
||||||
rotary_emb: bool = True
|
rotary_emb: bool = True
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
kwargs = self.dict()
|
||||||
|
return DiffusionPriorNetwork(**kwargs)
|
||||||
|
|
||||||
class DiffusionPriorConfig(BaseModel):
|
class DiffusionPriorConfig(BaseModel):
|
||||||
# only clip-less diffusion prior config for now
|
clip: AdapterConfig
|
||||||
net: DiffusionPriorNetworkConfig
|
net: DiffusionPriorNetworkConfig
|
||||||
image_embed_dim: int
|
image_embed_dim: int
|
||||||
image_size: int
|
image_size: int
|
||||||
@@ -46,15 +98,52 @@ class DiffusionPriorConfig(BaseModel):
|
|||||||
loss_type: str = 'l2'
|
loss_type: str = 'l2'
|
||||||
predict_x_start: bool = True
|
predict_x_start: bool = True
|
||||||
beta_schedule: str = 'cosine'
|
beta_schedule: str = 'cosine'
|
||||||
|
condition_on_text_encodings: bool = True
|
||||||
def create(self):
|
|
||||||
kwargs = self.dict()
|
|
||||||
diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
|
|
||||||
return DiffusionPrior(net = diffusion_prior_network, **kwargs)
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
kwargs = self.dict()
|
||||||
|
clip = AdapterConfig(**kwargs.pop('clip')).create()
|
||||||
|
diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create()
|
||||||
|
return DiffusionPrior(net = diffusion_prior_network, clip=clip, **kwargs)
|
||||||
|
|
||||||
|
class DiffusionPriorTrainConfig(BaseModel):
    """Values for the `train` section of the diffusion prior training config."""
    epochs: int = 1
    lr: float = 0.00011
    wd: float = 0.0602
    max_grad_norm: float = 0.5
    use_ema: bool = True
    ema_beta: float = 0.99
    amp: bool = False

    # what steps to save on
    save_every: int = 10_000
|
||||||
|
|
||||||
|
class DiffusionPriorDataConfig(BaseModel):
    """Values for the `data` section of the diffusion prior training config."""
    # path to embeddings folder
    image_url: str

    # path to metadata (captions) for images
    meta_url: str

    # train / val / test fractions
    splits: TrainSplitConfig

    batch_size: int = 64
|
||||||
|
|
||||||
|
class DiffusionPriorLoadConfig(BaseModel):
    """Values for the `load` section of the diffusion prior training config."""
    # The default is None (and the example config passes null), so the annotation
    # says Optional explicitly instead of relying on pydantic inferring it
    source: Optional[str] = None

    resume: bool = False
|
||||||
|
|
||||||
|
class TrainDiffusionPriorConfig(BaseModel):
    """Top-level diffusion prior training config, with one field per config section."""
    prior: DiffusionPriorConfig
    data: DiffusionPriorDataConfig
    train: DiffusionPriorTrainConfig
    load: DiffusionPriorLoadConfig
    tracker: TrackerConfig

    @classmethod
    def from_json_path(cls, json_path):
        """Build the config from a JSON file.

        json_path: path of a JSON file whose top-level keys are passed as the fields
        Returns an instance of this class.
        Raises whatever `open` / `json.load` raise, and pydantic's validation
        error when the contents do not match the declared fields.
        """
        # read as UTF-8 explicitly so parsing does not depend on the platform's default encoding
        with open(json_path, encoding = 'utf-8') as f:
            config = json.load(f)

        return cls(**config)
|
||||||
|
|
||||||
|
# decoder pydantic classes
|
||||||
|
|
||||||
class UnetConfig(BaseModel):
|
class UnetConfig(BaseModel):
|
||||||
dim: int
|
dim: int
|
||||||
dim_mults: ListOrTuple(int)
|
dim_mults: ListOrTuple(int)
|
||||||
@@ -94,17 +183,6 @@ class DecoderConfig(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
class TrainSplitConfig(BaseModel):
|
|
||||||
train: float = 0.75
|
|
||||||
val: float = 0.15
|
|
||||||
test: float = 0.1
|
|
||||||
|
|
||||||
@root_validator
|
|
||||||
def validate_all(cls, fields):
|
|
||||||
if sum([*fields.values()]) != 1.:
|
|
||||||
raise ValueError(f'{fields.keys()} must sum to 1.0')
|
|
||||||
return fields
|
|
||||||
|
|
||||||
class DecoderDataConfig(BaseModel):
|
class DecoderDataConfig(BaseModel):
|
||||||
webdataset_base_url: str # path to a webdataset with jpg images
|
webdataset_base_url: str # path to a webdataset with jpg images
|
||||||
embeddings_url: str # path to .npy files with embeddings
|
embeddings_url: str # path to .npy files with embeddings
|
||||||
@@ -160,14 +238,6 @@ class DecoderEvaluateConfig(BaseModel):
|
|||||||
KID: Dict[str, Any] = None
|
KID: Dict[str, Any] = None
|
||||||
LPIPS: Dict[str, Any] = None
|
LPIPS: Dict[str, Any] = None
|
||||||
|
|
||||||
class TrackerConfig(BaseModel):
|
|
||||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
|
||||||
data_path: str = './models' # The path where files will be saved locally
|
|
||||||
init_config: Dict[str, Any] = None
|
|
||||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
|
||||||
wandb_project: str = ''
|
|
||||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
|
||||||
|
|
||||||
class DecoderLoadConfig(BaseModel):
|
class DecoderLoadConfig(BaseModel):
|
||||||
source: str = None # Supports file and wandb
|
source: str = None # Supports file and wandb
|
||||||
run_path: str = '' # Used only if source is wandb
|
run_path: str = '' # Used only if source is wandb
|
||||||
|
|||||||
Reference in New Issue
Block a user