From f980a9df1460ee53b838b589a758bc49eca4e9b8 Mon Sep 17 00:00:00 2001 From: zsyOAOA Date: Wed, 11 Dec 2024 21:00:33 +0800 Subject: [PATCH] add chopping size options --- inference_invsr.py | 23 ++++++++++++++--------- sampler_invsr.py | 2 -- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/inference_invsr.py b/inference_invsr.py index 9a4eaa6..f43283c 100644 --- a/inference_invsr.py +++ b/inference_invsr.py @@ -37,6 +37,9 @@ def get_parser(**parser_kwargs): parser.add_argument( "--color_fix", type=str, default='', choices=['wavelet', 'ycbcr'], help="Fix the color shift", ) + parser.add_argument( + "--chopping_size", type=int, default=128, help="Chopping size when dealing large images" + ) args = parser.parse_args() return args @@ -77,18 +80,20 @@ def get_configs(args): started_ckpt_name = "noise_predictor_sd_turbo_v5.pth" started_ckpt_dir = "./weights" util_common.mkdir(started_ckpt_dir, delete=False, parents=True) - started_ckpt_path = str(Path(started_ckpt_dir) / started_ckpt_name) - load_file_from_url( - url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth", - model_dir=started_ckpt_dir, - progress=True, - file_name=started_ckpt_name, - ) - configs.model_start.ckpt_path = started_ckpt_path + started_ckpt_path = Path(started_ckpt_dir) / started_ckpt_name + if not started_ckpt_path.exists(): + load_file_from_url( + url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth", + model_dir=started_ckpt_dir, + progress=True, + file_name=started_ckpt_name, + ) + configs.model_start.ckpt_path = str(started_ckpt_path) + configs.bs = args.bs configs.tiled_vae = args.tiled_vae configs.color_fix = args.color_fix - configs.bs = args.bs + configs.basesr.chopping.pch_size = args.chopping_size return configs diff --git a/sampler_invsr.py b/sampler_invsr.py index dbe3abd..1417fae 100644 --- a/sampler_invsr.py +++ b/sampler_invsr.py @@ -134,8 +134,6 @@ class InvSamplerSR(BaseSampler): diffusion_sf = self.sd_pipe.transformer.patch_size mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf idle_pch_size = self.configs.basesr.chopping.pch_size - if ori_h_lq * ori_w_lq >= 512 ** 2: - idle_pch_size = 256 if min(im_cond.shape[-2:]) >= idle_pch_size: pad_h_up = pad_w_left = 0