let the pydantic config base model take care of loading configuration from json path

This commit is contained in:
Phil Wang
2022-05-22 14:47:23 -07:00
parent c6629c431a
commit c12e067178
3 changed files with 9 additions and 5 deletions

View File

@@ -1,3 +1,4 @@
import json
from torchvision import transforms as T from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
@@ -111,6 +112,12 @@ class TrainDecoderConfig(BaseModel):
tracker: TrackerConfig tracker: TrackerConfig
load: DecoderLoadConfig load: DecoderLoadConfig
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
@property @property
def img_preproc(self): def img_preproc(self):
def _get_transformation(transformation_name, **kwargs): def _get_transformation(transformation_name, **kwargs):

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.4.1', version = '0.4.2',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',

View File

@@ -5,7 +5,6 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer from dalle2_pytorch.utils import Timer
import json
import torchvision import torchvision
import torch import torch
from torchmetrics.image.fid import FrechetInceptionDistance from torchmetrics.image.fid import FrechetInceptionDistance
@@ -449,9 +448,7 @@ def initialize_training(config):
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file") @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file): def main(config_file):
print("Recalling config from {}".format(config_file)) print("Recalling config from {}".format(config_file))
with open(config_file) as f: config = TrainDecoderConfig.from_json_path(config_file)
config = json.load(f)
config = TrainDecoderConfig(**config)
initialize_training(config) initialize_training(config)