mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-16 22:04:20 +01:00
83 lines
2.2 KiB
Python
83 lines
2.2 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding:utf-8 -*-
|
|
# Powered by Zongsheng Yue 2023-10-26 20:20:36
|
|
|
|
import warnings
|
|
warnings.filterwarnings("ignore")
|
|
|
|
import argparse
|
|
from omegaconf import OmegaConf
|
|
|
|
from utils.util_common import get_obj_from_str
|
|
from utils.util_opts import str2bool
|
|
|
|
def get_parser(**parser_kwargs):
    """Parse the training command-line options and return the parsed namespace.

    Despite its name, this function does not hand back the parser: it builds
    an ``argparse.ArgumentParser`` and immediately parses the process
    command line with it.

    Args:
        **parser_kwargs: Forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``, carrying the
        attributes ``save_dir``, ``resume``, ``cfg_path``, ``ldif``,
        ``llpips``, ``ldis`` and ``use_text``.
    """
    # One (flag, add_argument keyword arguments) entry per option, listed in
    # registration order.
    option_specs = (
        (
            "--save_dir",
            dict(
                type=str,
                default="./save_dir",
                help="Folder to save the checkpoints and training log",
            ),
        ),
        (
            # nargs="?" makes the value optional: flag absent -> default "",
            # bare flag -> const (True, which argparse does not pass through
            # `type`), flag with a value -> that value as a str.
            "--resume",
            dict(
                type=str,
                const=True,
                default="",
                nargs="?",
                help="resume from the save_dir or checkpoint",
            ),
        ),
        (
            "--cfg_path",
            dict(
                type=str,
                default="./configs/sd-turbo-sr-ldis.yaml",
                help="Configs of yaml file",
            ),
        ),
        (
            "--ldif",
            dict(
                type=float,
                default=1.0,
                help="Loss coefficient for diffsuion in latent space",
            ),
        ),
        (
            "--llpips",
            dict(
                type=float,
                default=2.0,
                help="Loss coefficient for latent lpips",
            ),
        ),
        (
            "--ldis",
            dict(
                type=float,
                default=0.1,
                help="Loss coefficient for latent discriminator",
            ),
        ),
        (
            # The default is the *string* 'False'; argparse runs string
            # defaults through `type`, so the stored default is whatever
            # str2bool returns for it.
            "--use_text",
            dict(
                type=str2bool,
                default='False',
                help="Text Prompt",
            ),
        ),
    )

    cli = argparse.ArgumentParser(**parser_kwargs)
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the CLI options, load the YAML config, apply
    # the command-line overrides, then build the trainer and start training.
    args = get_parser()
    configs = OmegaConf.load(args.cfg_path)

    # A loss coefficient given on the command line replaces the value from
    # the YAML file only when it is strictly positive.
    for coef_name in ("ldif", "ldis", "llpips"):
        coef_value = getattr(args, coef_name)
        if coef_value > 0:
            setattr(configs.train.loss_coef, coef_name, coef_value)
    configs.train.use_text = args.use_text

    # merge args to config: only these three options are copied to the top
    # level of the config, in the order the parser declares them.
    for cli_key in ("save_dir", "resume", "cfg_path"):
        configs[cli_key] = getattr(args, cli_key)

    # NOTE(review): get_obj_from_str lives in utils.util_common (not visible
    # here); presumably it resolves the string in configs.trainer.target to
    # a trainer class -- confirm against that helper.
    trainer_factory = get_obj_from_str(configs.trainer.target)
    trainer = trainer_factory(configs)
    trainer.train()
|