From c12e06717857ff60dcf79d74cad9418ee02e277a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 22 May 2022 14:47:23 -0700 Subject: [PATCH] let the pydantic config base model take care of loading configuration from json path --- dalle2_pytorch/train_configs.py | 7 +++++++ setup.py | 2 +- train_decoder.py | 5 +---- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index d6b9edb..a57542c 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -1,3 +1,4 @@ +import json from torchvision import transforms as T from pydantic import BaseModel, validator, root_validator from typing import List, Iterable, Optional, Union, Tuple, Dict, Any @@ -111,6 +112,12 @@ class TrainDecoderConfig(BaseModel): tracker: TrackerConfig load: DecoderLoadConfig + @classmethod + def from_json_path(cls, json_path): + with open(json_path) as f: + config = json.load(f) + return cls(**config) + @property def img_preproc(self): def _get_transformation(transformation_name, **kwargs): diff --git a/setup.py b/setup.py index 3ce651a..563f1a2 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.4.1', + version = '0.4.2', license='MIT', description = 'DALL-E 2', author = 'Phil Wang', diff --git a/train_decoder.py b/train_decoder.py index 9e872e5..e2dfebe 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -5,7 +5,6 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker from dalle2_pytorch.train_configs import TrainDecoderConfig from dalle2_pytorch.utils import Timer -import json import torchvision import torch from torchmetrics.image.fid import FrechetInceptionDistance @@ -449,9 +448,7 @@ def initialize_training(config): @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file") def main(config_file): print("Recalling config from {}".format(config_file)) - with open(config_file) as f: - config = json.load(f) - config = TrainDecoderConfig(**config) + config = TrainDecoderConfig.from_json_path(config_file) initialize_training(config)