From f3d7e226ba28879d43bc1642158078045c128df1 Mon Sep 17 00:00:00 2001
From: Aidan Dempster
Date: Fri, 22 Jul 2022 16:16:29 -0400
Subject: [PATCH] Changed types to be generic instead of functions (#215)

This allows pylance to do proper type hinting and makes developing
extensions to the package much easier
---
 dalle2_pytorch/train_configs.py | 34 ++++++++++++++++------------------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 307f011..ae6407f 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -1,7 +1,7 @@
 import json
 from torchvision import transforms as T
 from pydantic import BaseModel, validator, root_validator
-from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
+from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
 
 from x_clip import CLIP as XCLIP
 from coca_pytorch import CoCa
@@ -25,11 +25,9 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
-def ListOrTuple(inner_type):
-    return Union[List[inner_type], Tuple[inner_type]]
-
-def SingularOrIterable(inner_type):
-    return Union[inner_type, ListOrTuple(inner_type)]
+InnerType = TypeVar('InnerType')
+ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
+SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
 
 # general pydantic classes
 
@@ -222,13 +220,13 @@ class TrainDiffusionPriorConfig(BaseModel):
 
 class UnetConfig(BaseModel):
     dim: int
-    dim_mults: ListOrTuple(int)
+    dim_mults: ListOrTuple[int]
     image_embed_dim: int = None
     text_embed_dim: int = None
     cond_on_text_encodings: bool = None
     cond_dim: int = None
     channels: int = 3
-    self_attn: ListOrTuple(int)
+    self_attn: ListOrTuple[int]
     attn_dim_head: int = 32
     attn_heads: int = 16
     init_cross_embed: bool = True
@@ -237,16 +235,16 @@ class UnetConfig(BaseModel):
         extra = "allow"
 
 class DecoderConfig(BaseModel):
-    unets: ListOrTuple(UnetConfig)
+    unets: ListOrTuple[UnetConfig]
     image_size: int = None
-    image_sizes: ListOrTuple(int) = None
+    image_sizes: ListOrTuple[int] = None
     clip: Optional[AdapterConfig]  # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
-    sample_timesteps: Optional[SingularOrIterable(int)] = None
+    sample_timesteps: Optional[SingularOrIterable[int]] = None
     loss_type: str = 'l2'
-    beta_schedule: ListOrTuple(str) = 'cosine'
-    learned_variance: bool = True
+    beta_schedule: ListOrTuple[str] = None # None means all cosine
+    learned_variance: SingularOrIterable[bool] = True
     image_cond_drop_prob: float = 0.1
     text_cond_drop_prob: float = 0.5
 
@@ -305,11 +303,11 @@ class DecoderDataConfig(BaseModel):
 
 class DecoderTrainConfig(BaseModel):
     epochs: int = 20
-    lr: SingularOrIterable(float) = 1e-4
-    wd: SingularOrIterable(float) = 0.01
-    warmup_steps: Optional[SingularOrIterable(int)] = None
+    lr: SingularOrIterable[float] = 1e-4
+    wd: SingularOrIterable[float] = 0.01
+    warmup_steps: Optional[SingularOrIterable[int]] = None
     find_unused_parameters: bool = True
-    max_grad_norm: SingularOrIterable(float) = 0.5
+    max_grad_norm: SingularOrIterable[float] = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
     cond_scale: Union[float, List[float]] = 1.0
@@ -320,7 +318,7 @@ class DecoderTrainConfig(BaseModel):
     use_ema: bool = True
     ema_beta: float = 0.999
     amp: bool = False
-    unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
+    unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
 
 class DecoderEvaluateConfig(BaseModel):
     n_evaluation_samples: int = 1000
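
For illustration, here is a minimal, self-contained sketch of the pattern this
patch introduces. The TypeVar and the two aliases are copied from the diff;
ExampleConfig and its fields are hypothetical stand-ins, and pydantic v1
semantics (as used by the package at the time) are assumed:

    from typing import List, Tuple, TypeVar, Union
    from pydantic import BaseModel

    InnerType = TypeVar('InnerType')

    # Generic aliases are subscriptable, so a static checker such as
    # pylance can expand ListOrTuple[int] to Union[List[int], Tuple[int]]
    # without executing anything. The old function spelling ListOrTuple(int)
    # only constructed that Union at runtime, which static analysis cannot
    # follow.
    ListOrTuple = Union[List[InnerType], Tuple[InnerType]]
    SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]

    class ExampleConfig(BaseModel):  # hypothetical, not part of the package
        dim_mults: ListOrTuple[int]
        lr: SingularOrIterable[float] = 1e-4

    cfg = ExampleConfig(dim_mults=[1, 2, 4])  # a plain list validates
    print(cfg.lr)                             # 0.0001

Subscripting composes as expected: SingularOrIterable[float] expands to
Union[float, List[float], Tuple[float]], which is why a single float, a list,
or a tuple are all accepted for fields like lr above.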