diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index ecdf994..cecd8c7 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -115,7 +115,7 @@ class TrackerConfig(BaseModel):
 class AdapterConfig(BaseModel):
     make: str = "openai"
     model: str = "ViT-L/14"
-    base_model_kwargs: Dict[str, Any] = None
+    base_model_kwargs: Optional[Dict[str, Any]] = None
 
     def create(self):
         if self.make == "openai":
@@ -134,8 +134,8 @@ class AdapterConfig(BaseModel):
 class DiffusionPriorNetworkConfig(BaseModel):
     dim: int
     depth: int
-    max_text_len: int = None
-    num_timesteps: int = None
+    max_text_len: Optional[int] = None
+    num_timesteps: Optional[int] = None
     num_time_embeds: int = 1
     num_image_embeds: int = 1
     num_text_embeds: int = 1
@@ -158,7 +158,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
         return DiffusionPriorNetwork(**kwargs)
 
 class DiffusionPriorConfig(BaseModel):
-    clip: AdapterConfig = None
+    clip: Optional[AdapterConfig] = None
     net: DiffusionPriorNetworkConfig
     image_embed_dim: int
     image_size: int
@@ -195,7 +195,7 @@ class DiffusionPriorTrainConfig(BaseModel):
     use_ema: bool = True
     ema_beta: float = 0.99
     amp: bool = False
-    warmup_steps: int = None # number of warmup steps
+    warmup_steps: Optional[int] = None # number of warmup steps
     save_every_seconds: int = 3600 # how often to save
     eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
     best_validation_loss: float = 1e9 # the current best valudation loss observed
@@ -228,10 +228,10 @@ class TrainDiffusionPriorConfig(BaseModel):
 class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple[int]
-    image_embed_dim: int = None
-    text_embed_dim: int = None
-    cond_on_text_encodings: bool = None
-    cond_dim: int = None
+    image_embed_dim: Optional[int] = None
+    text_embed_dim: Optional[int] = None
+    cond_on_text_encodings: Optional[bool] = None
+    cond_dim: Optional[int] = None
     channels: int = 3
     self_attn: ListOrTuple[int]
     attn_dim_head: int = 32
@@ -243,14 +243,14 @@ class UnetConfig(BaseModel):
 
 class DecoderConfig(BaseModel):
     unets: ListOrTuple[UnetConfig]
-    image_size: int = None
+    image_size: Optional[int] = None
     image_sizes: ListOrTuple[int] = None
     clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
     sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
     loss_type: str = 'l2'
-    beta_schedule: ListOrTuple[str] = None # None means all cosine
+    beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
     learned_variance: SingularOrIterable[bool] = True
     image_cond_drop_prob: float = 0.1
     text_cond_drop_prob: float = 0.5
@@ -320,20 +320,20 @@ class DecoderTrainConfig(BaseModel):
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
     cond_scale: Union[float, List[float]] = 1.0
     device: str = 'cuda:0'
-    epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
-    validation_samples: int = None # Same as above but for validation.
+    epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+    validation_samples: Optional[int] = None # Same as above but for validation.
     save_immediately: bool = False
     use_ema: bool = True
     ema_beta: float = 0.999
     amp: bool = False
-    unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
+    unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets
 
 class DecoderEvaluateConfig(BaseModel):
     n_evaluation_samples: int = 1000
-    FID: Dict[str, Any] = None
-    IS: Dict[str, Any] = None
-    KID: Dict[str, Any] = None
-    LPIPS: Dict[str, Any] = None
+    FID: Optional[Dict[str, Any]] = None
+    IS: Optional[Dict[str, Any]] = None
+    KID: Optional[Dict[str, Any]] = None
+    LPIPS: Optional[Dict[str, Any]] = None
 
 class TrainDecoderConfig(BaseModel):
     decoder: DecoderConfig
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 18ddad0..5fdff84 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.15.2'
+__version__ = '1.15.3'