add chopping bs in argparse

This commit is contained in:
zsyOAOA
2024-12-13 15:20:00 +08:00
parent f980a9df14
commit f138fdff1b
3 changed files with 9 additions and 2 deletions

View File

@@ -7,7 +7,6 @@ basesr:
chopping: # for latent diffusion
pch_size: 128
weight_type: Gaussian
extra_bs: 16 # 16 ----> 26G memory
# VAE settings
tiled_vae: True
@@ -22,6 +21,9 @@ cfg_scale: 1.0
# sampling settings
start_timesteps: 200
# color fixing
color_fix: ~
# Stable Diffusion
base_model: sd-turbo
sd_pipe:

View File

@@ -20,6 +20,7 @@ def get_parser(**parser_kwargs):
parser.add_argument("-i", "--in_path", type=str, default="", help="Input path")
parser.add_argument("-o", "--out_path", type=str, default="", help="Output path")
parser.add_argument("--bs", type=int, default=1, help="Batchsize for loading image")
parser.add_argument("--chopping_bs", type=int, default=8, help="Batchsize for chopped patch")
parser.add_argument("-t", "--timesteps", type=int, nargs="+", help="The inversed timesteps")
parser.add_argument("-n", "--num_steps", type=int, default=1, help="Number of inference steps")
parser.add_argument(
@@ -94,6 +95,10 @@ def get_configs(args):
configs.tiled_vae = args.tiled_vae
configs.color_fix = args.color_fix
configs.basesr.chopping.pch_size = args.chopping_size
if args.bs > 1:
configs.basesr.chopping.extra_bs = 1
else:
configs.basesr.chopping.extra_bs = args.chopping_bs
return configs

View File

@@ -173,7 +173,7 @@ class InvSamplerSR(BaseSampler):
stride= int(idle_pch_size * 0.50),
sf=self.configs.basesr.sf,
weight_type=self.configs.basesr.chopping.weight_type,
extra_bs=1 if self.configs.bs > 1 else self.configs.basesr.chopping.extra_bs,
extra_bs=self.configs.basesr.chopping.extra_bs,
)
for im_lq_pch, index_infos in im_spliter:
target_size = (