mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 06:14:22 +01:00
add chopping size options
This commit is contained in:
@@ -37,6 +37,9 @@ def get_parser(**parser_kwargs):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--color_fix", type=str, default='', choices=['wavelet', 'ycbcr'], help="Fix the color shift",
|
"--color_fix", type=str, default='', choices=['wavelet', 'ycbcr'], help="Fix the color shift",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--chopping_size", type=int, default=128, help="Chopping size when dealing large images"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return args
|
return args
|
||||||
@@ -77,18 +80,20 @@ def get_configs(args):
|
|||||||
started_ckpt_name = "noise_predictor_sd_turbo_v5.pth"
|
started_ckpt_name = "noise_predictor_sd_turbo_v5.pth"
|
||||||
started_ckpt_dir = "./weights"
|
started_ckpt_dir = "./weights"
|
||||||
util_common.mkdir(started_ckpt_dir, delete=False, parents=True)
|
util_common.mkdir(started_ckpt_dir, delete=False, parents=True)
|
||||||
started_ckpt_path = str(Path(started_ckpt_dir) / started_ckpt_name)
|
started_ckpt_path = Path(started_ckpt_dir) / started_ckpt_name
|
||||||
load_file_from_url(
|
if not started_ckpt_path.exists():
|
||||||
url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth",
|
load_file_from_url(
|
||||||
model_dir=started_ckpt_dir,
|
url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth",
|
||||||
progress=True,
|
model_dir=started_ckpt_dir,
|
||||||
file_name=started_ckpt_name,
|
progress=True,
|
||||||
)
|
file_name=started_ckpt_name,
|
||||||
configs.model_start.ckpt_path = started_ckpt_path
|
)
|
||||||
|
configs.model_start.ckpt_path = str(started_ckpt_path)
|
||||||
|
|
||||||
|
configs.bs = args.bs
|
||||||
configs.tiled_vae = args.tiled_vae
|
configs.tiled_vae = args.tiled_vae
|
||||||
configs.color_fix = args.color_fix
|
configs.color_fix = args.color_fix
|
||||||
configs.bs = args.bs
|
configs.basesr.chopping.pch_size = args.chopping_size
|
||||||
|
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|||||||
@@ -134,8 +134,6 @@ class InvSamplerSR(BaseSampler):
|
|||||||
diffusion_sf = self.sd_pipe.transformer.patch_size
|
diffusion_sf = self.sd_pipe.transformer.patch_size
|
||||||
mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf
|
mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf
|
||||||
idle_pch_size = self.configs.basesr.chopping.pch_size
|
idle_pch_size = self.configs.basesr.chopping.pch_size
|
||||||
if ori_h_lq * ori_w_lq >= 512 ** 2:
|
|
||||||
idle_pch_size = 256
|
|
||||||
|
|
||||||
if min(im_cond.shape[-2:]) >= idle_pch_size:
|
if min(im_cond.shape[-2:]) >= idle_pch_size:
|
||||||
pad_h_up = pad_w_left = 0
|
pad_h_up = pad_w_left = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user