mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
let the pydantic config base model take care of loading configuration from json path
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from torchvision import transforms as T
|
from torchvision import transforms as T
|
||||||
from pydantic import BaseModel, validator, root_validator
|
from pydantic import BaseModel, validator, root_validator
|
||||||
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
|
||||||
@@ -111,6 +112,12 @@ class TrainDecoderConfig(BaseModel):
|
|||||||
tracker: TrackerConfig
|
tracker: TrackerConfig
|
||||||
load: DecoderLoadConfig
|
load: DecoderLoadConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_path(cls, json_path):
|
||||||
|
with open(json_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return cls(**config)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def img_preproc(self):
|
def img_preproc(self):
|
||||||
def _get_transformation(transformation_name, **kwargs):
|
def _get_transformation(transformation_name, **kwargs):
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.4.1',
|
version = '0.4.2',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
|||||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
from dalle2_pytorch.utils import Timer
|
from dalle2_pytorch.utils import Timer
|
||||||
|
|
||||||
import json
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch
|
import torch
|
||||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||||
@@ -449,9 +448,7 @@ def initialize_training(config):
|
|||||||
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||||
def main(config_file):
|
def main(config_file):
|
||||||
print("Recalling config from {}".format(config_file))
|
print("Recalling config from {}".format(config_file))
|
||||||
with open(config_file) as f:
|
config = TrainDecoderConfig.from_json_path(config_file)
|
||||||
config = json.load(f)
|
|
||||||
config = TrainDecoderConfig(**config)
|
|
||||||
initialize_training(config)
|
initialize_training(config)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user