mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-23 03:24:20 +01:00
Prior updates (#211)
* update configs for prior add prior warmup to config update example prior config * update prior trainer & script add deepspeed amp & warmup adopt full accelerator support reload at sample point finish epoch resume code * update tracker save method for prior * helper functions for prior_loader
This commit is contained in:
@@ -145,6 +145,9 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
normformer: bool = False
|
||||
rotary_emb: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
return DiffusionPriorNetwork(**kwargs)
|
||||
@@ -187,23 +190,26 @@ class DiffusionPriorTrainConfig(BaseModel):
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
amp: bool = False
|
||||
save_every: int = 10000 # what steps to save on
|
||||
warmup_steps: int = None # number of warmup steps
|
||||
save_every_seconds: int = 3600 # how often to save
|
||||
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
|
||||
best_validation_loss: float = 1e9 # the current best valudation loss observed
|
||||
current_epoch: int = 0 # the current epoch
|
||||
num_samples_seen: int = 0 # the current number of samples seen
|
||||
random_seed: int = 0 # manual seed for torch
|
||||
|
||||
class DiffusionPriorDataConfig(BaseModel):
|
||||
image_url: str # path to embeddings folder
|
||||
meta_url: str # path to metadata (captions) for images
|
||||
splits: TrainSplitConfig
|
||||
batch_size: int = 64
|
||||
|
||||
class DiffusionPriorLoadConfig(BaseModel):
|
||||
source: str = None
|
||||
resume: bool = False
|
||||
image_url: str # path to embeddings folder
|
||||
meta_url: str # path to metadata (captions) for images
|
||||
splits: TrainSplitConfig # define train, validation, test splits for your dataset
|
||||
batch_size: int # per-gpu batch size used to train the model
|
||||
num_data_points: int = 25e7 # total number of datapoints to train on
|
||||
eval_every_seconds: int = 3600 # validation statistics will be performed this often
|
||||
|
||||
class TrainDiffusionPriorConfig(BaseModel):
|
||||
prior: DiffusionPriorConfig
|
||||
data: DiffusionPriorDataConfig
|
||||
train: DiffusionPriorTrainConfig
|
||||
load: DiffusionPriorLoadConfig
|
||||
tracker: TrackerConfig
|
||||
|
||||
@classmethod
|
||||
@@ -323,12 +329,6 @@ class DecoderEvaluateConfig(BaseModel):
|
||||
KID: Dict[str, Any] = None
|
||||
LPIPS: Dict[str, Any] = None
|
||||
|
||||
class DecoderLoadConfig(BaseModel):
|
||||
source: str = None # Supports file and wandb
|
||||
run_path: str = '' # Used only if source is wandb
|
||||
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
|
||||
resume: bool = False # If using wandb, whether to resume the run
|
||||
|
||||
class TrainDecoderConfig(BaseModel):
|
||||
decoder: DecoderConfig
|
||||
data: DecoderDataConfig
|
||||
|
||||
Reference in New Issue
Block a user