add chopping size options

This commit is contained in:
zsyOAOA
2024-12-11 21:00:33 +08:00
parent 8d7e434e4f
commit f980a9df14
2 changed files with 14 additions and 11 deletions

View File

@@ -37,6 +37,9 @@ def get_parser(**parser_kwargs):
parser.add_argument( parser.add_argument(
"--color_fix", type=str, default='', choices=['wavelet', 'ycbcr'], help="Fix the color shift", "--color_fix", type=str, default='', choices=['wavelet', 'ycbcr'], help="Fix the color shift",
) )
parser.add_argument(
"--chopping_size", type=int, default=128, help="Chopping size when dealing large images"
)
args = parser.parse_args() args = parser.parse_args()
return args return args
@@ -77,18 +80,20 @@ def get_configs(args):
started_ckpt_name = "noise_predictor_sd_turbo_v5.pth" started_ckpt_name = "noise_predictor_sd_turbo_v5.pth"
started_ckpt_dir = "./weights" started_ckpt_dir = "./weights"
util_common.mkdir(started_ckpt_dir, delete=False, parents=True) util_common.mkdir(started_ckpt_dir, delete=False, parents=True)
started_ckpt_path = str(Path(started_ckpt_dir) / started_ckpt_name) started_ckpt_path = Path(started_ckpt_dir) / started_ckpt_name
if not started_ckpt_path.exists():
load_file_from_url( load_file_from_url(
url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth", url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth",
model_dir=started_ckpt_dir, model_dir=started_ckpt_dir,
progress=True, progress=True,
file_name=started_ckpt_name, file_name=started_ckpt_name,
) )
configs.model_start.ckpt_path = started_ckpt_path configs.model_start.ckpt_path = str(started_ckpt_path)
configs.bs = args.bs
configs.tiled_vae = args.tiled_vae configs.tiled_vae = args.tiled_vae
configs.color_fix = args.color_fix configs.color_fix = args.color_fix
configs.bs = args.bs configs.basesr.chopping.pch_size = args.chopping_size
return configs return configs

View File

@@ -134,8 +134,6 @@ class InvSamplerSR(BaseSampler):
diffusion_sf = self.sd_pipe.transformer.patch_size diffusion_sf = self.sd_pipe.transformer.patch_size
mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf
idle_pch_size = self.configs.basesr.chopping.pch_size idle_pch_size = self.configs.basesr.chopping.pch_size
if ori_h_lq * ori_w_lq >= 512 ** 2:
idle_pch_size = 256
if min(im_cond.shape[-2:]) >= idle_pch_size: if min(im_cond.shape[-2:]) >= idle_pch_size:
pad_h_up = pad_w_left = 0 pad_h_up = pad_w_left = 0