first commit

This commit is contained in:
zsyOAOA
2024-12-11 18:46:36 +08:00
parent 9e65255d34
commit 27f2eb7dc3
847 changed files with 377076 additions and 2 deletions

82
main.py Normal file
View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2023-10-26 20:20:36
import warnings
warnings.filterwarnings("ignore")
import argparse
from omegaconf import OmegaConf
from utils.util_common import get_obj_from_str
from utils.util_opts import str2bool
def get_parser(**parser_kwargs):
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
"--save_dir",
type=str,
default="./save_dir",
help="Folder to save the checkpoints and training log",
)
parser.add_argument(
"--resume",
type=str,
const=True,
default="",
nargs="?",
help="resume from the save_dir or checkpoint",
)
parser.add_argument(
"--cfg_path",
type=str,
default="./configs/sd-turbo-sr-ldis.yaml",
help="Configs of yaml file",
)
parser.add_argument(
"--ldif",
type=float,
default=1.0,
help="Loss coefficient for diffsuion in latent space",
)
parser.add_argument(
"--llpips",
type=float,
default=2.0,
help="Loss coefficient for latent lpips",
)
parser.add_argument(
"--ldis",
type=float,
default=0.1,
help="Loss coefficient for latent discriminator",
)
parser.add_argument(
"--use_text",
type=str2bool,
default='False',
help="Text Prompt",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_parser()
configs = OmegaConf.load(args.cfg_path)
if args.ldif > 0:
configs.train.loss_coef.ldif = args.ldif
if args.ldis > 0:
configs.train.loss_coef.ldis = args.ldis
if args.llpips > 0:
configs.train.loss_coef.llpips = args.llpips
configs.train.use_text = args.use_text
# merge args to config
for key in vars(args):
if key in ['cfg_path', 'save_dir', 'resume', ]:
configs[key] = getattr(args, key)
trainer = get_obj_from_str(configs.trainer.target)(configs)
trainer.train()