mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 06:14:22 +01:00
111 lines
3.9 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding:utf-8 -*-
|
|
# Powered by Zongsheng Yue 2023-03-11 17:17:41
|
|
|
|
import warnings
# Silence every Python warning process-wide; this runs before the remaining
# imports below, so warnings emitted while importing them are suppressed too.
warnings.filterwarnings("ignore")
|
|
|
|
import argparse
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from omegaconf import OmegaConf
|
|
from sampler_invsr import InvSamplerSR
|
|
|
|
from utils import util_common
|
|
from utils.util_opts import str2bool
|
|
from basicsr.utils.download_util import load_file_from_url
|
|
|
|
def get_parser(**parser_kwargs):
    """Build the InvSR inference command-line interface and parse ``sys.argv``.

    Note: despite the name, this returns the *parsed* namespace, not the
    ``ArgumentParser`` object.

    Args:
        **parser_kwargs: Forwarded verbatim to ``argparse.ArgumentParser``.

    Returns:
        argparse.Namespace: The parsed command-line options.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("-i", "--in_path", type=str, default="", help="Input path")
    parser.add_argument("-o", "--out_path", type=str, default="", help="Output path")
    parser.add_argument("--bs", type=int, default=1, help="Batchsize for loading image")
    parser.add_argument("-t", "--timesteps", type=int, nargs="+", help="The inversed timesteps")
    parser.add_argument("-n", "--num_steps", type=int, default=1, help="Number of inference steps")
    # get_configs reads ``args.started_step`` to build an evenly spaced schedule
    # when more than five steps are requested without explicit --timesteps;
    # previously the option was never defined, so that path raised AttributeError.
    parser.add_argument(
        "--started_step", type=int, default=250,
        help="Starting timestep of the evenly spaced schedule (used when num_steps > 5 and -t is not given)",
    )
    parser.add_argument(
        "--cfg_path", type=str, default="./configs/sample-sd-turbo.yaml", help="Configuration path.",
    )
    parser.add_argument(
        "--sd_path", type=str, default="", help="Path for Stable Diffusion Model",
    )
    parser.add_argument(
        "--started_ckpt_path", type=str, default="", help="Checkpoint path for noise predictor"
    )
    # argparse applies ``type`` to string defaults, so 'true' goes through str2bool.
    parser.add_argument(
        "--tiled_vae", type=str2bool, default='true', help="Enabled tiled VAE.",
    )
    # The '' default is not among ``choices``; argparse only validates values
    # given on the command line, so omitting the flag leaves color_fix == ''.
    parser.add_argument(
        "--color_fix", type=str, default='', choices=['wavelet', 'ycbcr'], help="Fix the color shift",
    )
    parser.add_argument(
        "--chopping_size", type=int, default=128, help="Chopping size when dealing large images"
    )
    args = parser.parse_args()

    return args
|
|
|
|
def get_configs(args):
    """Load the YAML config and override it with the command-line options.

    Besides copying CLI values into the config, this resolves the timestep
    schedule, creates the Stable Diffusion cache directory, and downloads the
    default noise-predictor checkpoint into ``./weights`` when no checkpoint
    path is given and the file is not already there.

    Args:
        args: Namespace from ``get_parser``. ``started_step`` is optional and
            defaults to 250 when the namespace does not provide it.

    Returns:
        The OmegaConf config loaded from ``args.cfg_path`` with overrides applied.

    Raises:
        ValueError: If ``--timesteps`` has a length different from
            ``num_steps``, or ``num_steps`` is outside ``[1, started_step]``.
    """
    configs = OmegaConf.load(args.cfg_path)

    # Hand-picked schedules for small step counts, largest timestep first.
    preset_timesteps = {
        1: [200],
        2: [200, 100],
        3: [200, 100, 50],
        4: [200, 150, 100, 50],
        5: [250, 200, 150, 100, 50],
    }

    if args.timesteps is not None:
        if len(args.timesteps) != args.num_steps:
            raise ValueError(
                f"Got {len(args.timesteps)} timesteps but num_steps={args.num_steps}; "
                "they must match."
            )
        configs.timesteps = sorted(args.timesteps, reverse=True)
    elif args.num_steps in preset_timesteps:
        configs.timesteps = preset_timesteps[args.num_steps]
    else:
        # getattr keeps this working with parsers that do not define --started_step.
        started_step = getattr(args, "started_step", 250)
        # More steps than started_step would yield duplicate integer timesteps;
        # fewer than one would yield an empty (or invalid) schedule.
        if not 1 <= args.num_steps <= started_step:
            raise ValueError(
                f"num_steps must be in [1, {started_step}], got {args.num_steps}."
            )
        # endpoint=False excludes timestep 0 from the descending schedule.
        configs.timesteps = np.linspace(
            start=started_step, stop=0, num=args.num_steps, endpoint=False, dtype=np.int64
        ).tolist()
    print(f'Setting timesteps for inference: {configs.timesteps}')

    # path to save Stable Diffusion
    sd_path = args.sd_path if args.sd_path else "./weights"
    util_common.mkdir(sd_path, delete=False, parents=True)
    configs.sd_pipe.params.cache_dir = sd_path

    # path to save noise predictor
    if args.started_ckpt_path:
        started_ckpt_path = args.started_ckpt_path
    else:
        started_ckpt_name = "noise_predictor_sd_turbo_v5.pth"
        started_ckpt_dir = "./weights"
        util_common.mkdir(started_ckpt_dir, delete=False, parents=True)
        started_ckpt_path = Path(started_ckpt_dir) / started_ckpt_name
        # Fetch the default checkpoint only once; later runs reuse the local copy.
        if not started_ckpt_path.exists():
            load_file_from_url(
                url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth",
                model_dir=started_ckpt_dir,
                progress=True,
                file_name=started_ckpt_name,
            )
    configs.model_start.ckpt_path = str(started_ckpt_path)

    configs.bs = args.bs
    configs.tiled_vae = args.tiled_vae
    configs.color_fix = args.color_fix
    configs.basesr.chopping.pch_size = args.chopping_size

    return configs
|
|
|
|
def main():
    """Entry point: parse CLI options, build the config, and run InvSR inference."""
    cli_args = get_parser()
    sampler = InvSamplerSR(get_configs(cli_args))
    # bs is passed again here alongside configs.bs, exactly as the CLI provided it.
    sampler.inference(cli_args.in_path, out_path=cli_args.out_path, bs=cli_args.bs)


# Only run inference when executed as a script, not on import.
if __name__ == '__main__':
    main()
|