From d7a0a2ce4be76c850a44ad962b99108cb20d331c Mon Sep 17 00:00:00 2001 From: zion <51308183+nousr@users.noreply.github.com> Date: Wed, 25 May 2022 09:06:50 -0700 Subject: [PATCH] add more support for configuring prior (#113) --- configs/train_prior_config.example.json | 70 +++++++++++++ dalle2_pytorch/train_configs.py | 124 ++++++++++++++++++------ 2 files changed, 167 insertions(+), 27 deletions(-) create mode 100644 configs/train_prior_config.example.json diff --git a/configs/train_prior_config.example.json b/configs/train_prior_config.example.json new file mode 100644 index 0000000..151ca28 --- /dev/null +++ b/configs/train_prior_config.example.json @@ -0,0 +1,70 @@ +{ + "prior": { + "clip": { + "make": "x-clip", + "model": "ViT-L/14", + "base_model_kwargs": { + "dim_text": 768, + "dim_image": 768, + "dim_latent": 768 + } + }, + "net": { + "dim": 768, + "depth": 12, + "num_timesteps": 1000, + "num_time_embeds": 1, + "num_image_embeds": 1, + "num_text_embeds": 1, + "dim_head": 64, + "heads": 12, + "ff_mult": 4, + "norm_out": true, + "attn_dropout": 0.0, + "ff_dropout": 0.0, + "final_proj": true, + "normformer": true, + "rotary_emb": true + }, + "image_embed_dim": 768, + "image_size": 224, + "image_channels": 3, + "timesteps": 1000, + "cond_drop_prob": 0.1, + "loss_type": "l2", + "predict_x_start": true, + "beta_schedule": "cosine", + "condition_on_text_encodings": true + }, + "data": { + "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/", + "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/", + "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/", + "batch_size": 256, + "splits": { + "train": 0.9, + "val": 1e-7, + "test": 0.0999999 + } + }, + "train": { + "epochs": 1, + "lr": 1.1e-4, + "wd": 6.02e-2, + "max_grad_norm": 0.5, + "use_ema": true, + "amp": false, + "save_every": 10000 + }, + "load": { + "source": null, + "resume": false + }, + "tracker": { + "tracker_type": "wandb", + "data_path": "./prior_checkpoints", + "wandb_entity": "laion", + "wandb_project": "diffusion-prior", + "verbose": true + } +} diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index 695f5ee..36be714 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -3,7 +3,18 @@ from torchvision import transforms as T from pydantic import BaseModel, validator, root_validator from typing import List, Iterable, Optional, Union, Tuple, Dict, Any -from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork +from x_clip import CLIP as XCLIP +from coca_pytorch import CoCa + +from dalle2_pytorch.dalle2_pytorch import ( + CoCaAdapter, + OpenAIClipAdapter, + Unet, + Decoder, + DiffusionPrior, + DiffusionPriorNetwork, + XClipAdapter, +) # helper functions @@ -16,7 +27,44 @@ def default(val, d): def ListOrTuple(inner_type): return Union[List[inner_type], Tuple[inner_type]] -# pydantic classes +# general pydantic classes + +class TrainSplitConfig(BaseModel): + train: float = 0.75 + val: float = 0.15 + test: float = 0.1 + + @root_validator + def validate_all(cls, fields): + actual_sum = sum([*fields.values()]) + if actual_sum != 1.: + raise ValueError(f'{fields.keys()} must sum to 1.0. 
Found: {actual_sum}') + return fields + +class TrackerConfig(BaseModel): + tracker_type: str = 'console' # Decoder currently supports console and wandb + data_path: str = './models' # The path where files will be saved locally + init_config: Dict[str, Any] = None + wandb_entity: str = '' # Only needs to be set if tracker_type is wandb + wandb_project: str = '' + verbose: bool = False # Whether to print console logging for non-console trackers + +# diffusion prior pydantic classes + +class AdapterConfig(BaseModel): + make: str = "openai" + model: str = "ViT-L/14" + base_model_kwargs: Dict[str, Any] = None + + def create(self): + if self.make == "openai": + return OpenAIClipAdapter(self.model) + elif self.make == "x-clip": + return XClipAdapter(XCLIP(**self.base_model_kwargs)) + elif self.make == "coca": + return CoCaAdapter(CoCa(**self.base_model_kwargs)) + else: + raise AttributeError("No adapter with that name is available.") class DiffusionPriorNetworkConfig(BaseModel): dim: int @@ -35,8 +83,12 @@ class DiffusionPriorNetworkConfig(BaseModel): normformer: bool = False rotary_emb: bool = True + def create(self): + kwargs = self.dict() + return DiffusionPriorNetwork(**kwargs) + class DiffusionPriorConfig(BaseModel): - # only clip-less diffusion prior config for now + clip: AdapterConfig net: DiffusionPriorNetworkConfig image_embed_dim: int image_size: int @@ -46,15 +98,52 @@ class DiffusionPriorConfig(BaseModel): loss_type: str = 'l2' predict_x_start: bool = True beta_schedule: str = 'cosine' - - def create(self): - kwargs = self.dict() - diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net')) - return DiffusionPrior(net = diffusion_prior_network, **kwargs) + condition_on_text_encodings: bool = True class Config: extra = "allow" + def create(self): + kwargs = self.dict() + clip = AdapterConfig(**kwargs.pop('clip')).create() + diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create() + return DiffusionPrior(net = diffusion_prior_network, clip=clip, **kwargs) + +class DiffusionPriorTrainConfig(BaseModel): + epochs: int = 1 + lr: float = 1.1e-4 + wd: float = 6.02e-2 + max_grad_norm: float = 0.5 + use_ema: bool = True + ema_beta: float = 0.99 + amp: bool = False + save_every: int = 10000 # what steps to save on + +class DiffusionPriorDataConfig(BaseModel): + image_url: str # path to embeddings folder + meta_url: str # path to metadata (captions) for images + splits: TrainSplitConfig + batch_size: int = 64 + +class DiffusionPriorLoadConfig(BaseModel): + source: str = None + resume: bool = False + +class TrainDiffusionPriorConfig(BaseModel): + prior: DiffusionPriorConfig + data: DiffusionPriorDataConfig + train: DiffusionPriorTrainConfig + load: DiffusionPriorLoadConfig + tracker: TrackerConfig + + @classmethod + def from_json_path(cls, json_path): + with open(json_path) as f: + config = json.load(f) + return cls(**config) + +# decoder pydantic classes + class UnetConfig(BaseModel): dim: int dim_mults: ListOrTuple(int) @@ -94,17 +183,6 @@ class DecoderConfig(BaseModel): class Config: extra = "allow" -class TrainSplitConfig(BaseModel): - train: float = 0.75 - val: float = 0.15 - test: float = 0.1 - - @root_validator - def validate_all(cls, fields): - if sum([*fields.values()]) != 1.: - raise ValueError(f'{fields.keys()} must sum to 1.0') - return fields - class DecoderDataConfig(BaseModel): webdataset_base_url: str # path to a webdataset with jpg images embeddings_url: str # path to .npy files with embeddings @@ -160,14 +238,6 @@ class 
DecoderEvaluateConfig(BaseModel): KID: Dict[str, Any] = None LPIPS: Dict[str, Any] = None -class TrackerConfig(BaseModel): - tracker_type: str = 'console' # Decoder currently supports console and wandb - data_path: str = './models' # The path where files will be saved locally - init_config: Dict[str, Any] = None - wandb_entity: str = '' # Only needs to be set if tracker_type is wandb - wandb_project: str = '' - verbose: bool = False # Whether to print console logging for non-console trackers - class DecoderLoadConfig(BaseModel): source: str = None # Supports file and wandb run_path: str = '' # Used only if source is wandb
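
Usage sketch (illustrative, not part of the diff): a minimal example of how the new prior config classes introduced above might be consumed, assuming the example JSON is saved as configs/train_prior_config.example.json, that x-clip and coca-pytorch are installed, and that the split fractions satisfy the exact sum-to-1.0 check in TrainSplitConfig. The script below is hypothetical and only exercises classes defined in this patch.

from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

# parse and validate the full training config from the example JSON
config = TrainDiffusionPriorConfig.from_json_path('configs/train_prior_config.example.json')

# build the diffusion prior: this instantiates the clip adapter (x-clip in the
# example config), the DiffusionPriorNetwork, and wraps both in a DiffusionPrior
diffusion_prior = config.prior.create()

# the remaining sub-configs are plain pydantic models read by the training script
lr           = config.train.lr              # 1.1e-4 in the example config
batch_size   = config.data.batch_size       # 256 in the example config
tracker_type = config.tracker.tracker_type  # 'wandb' in the example config

Keeping a create() method on each pydantic model keeps the training script free of model-construction details; the same JSON can be rewired to the 'openai' or 'coca' adapters by changing only the clip.make field.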