mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
let the pydantic config base model take care of loading configuration from json path
This commit is contained in:
@@ -5,7 +5,6 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer
|
||||
|
||||
import json
|
||||
import torchvision
|
||||
import torch
|
||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
@@ -449,9 +448,7 @@ def initialize_training(config):
|
||||
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||
def main(config_file):
|
||||
print("Recalling config from {}".format(config_file))
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
config = TrainDecoderConfig(**config)
|
||||
config = TrainDecoderConfig.from_json_path(config_file)
|
||||
initialize_training(config)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user