mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 06:14:22 +01:00
92 lines
3.3 KiB
Python
92 lines
3.3 KiB
Python
# Prediction interface for Cog ⚙️
|
|
# https://cog.run/python
|
|
|
|
|
|
import shutil, os
|
|
from omegaconf import OmegaConf
|
|
from cog import BasePredictor, Input, Path
|
|
|
|
import numpy as np
|
|
from utils import util_common
|
|
from sampler_invsr import InvSamplerSR
|
|
from basicsr.utils.download_util import load_file_from_url
|
|
|
|
class Predictor(BasePredictor):
|
|
def setup(self) -> None:
|
|
self.configs = OmegaConf.load("./configs/sample-sd-turbo.yaml")
|
|
|
|
def set_configs(self, num_steps=1, chopping_size=128, seed=12345):
|
|
if num_steps == 1:
|
|
self.configs.timesteps = [200,]
|
|
elif num_steps == 2:
|
|
self.configs.timesteps = [200, 100]
|
|
elif num_steps == 3:
|
|
self.configs.timesteps = [200, 100, 50]
|
|
elif num_steps == 4:
|
|
self.configs.timesteps = [200, 150, 100, 50]
|
|
elif num_steps == 5:
|
|
self.configs.timesteps = [250, 200, 150, 100, 50]
|
|
else:
|
|
assert num_steps <= 250
|
|
self.configs.timesteps = np.linspace(
|
|
start=250, stop=0, num=num_steps, endpoint=False, dtype=np.int64()
|
|
).tolist()
|
|
print(f'Setting timesteps for inference: {self.configs.timesteps}')
|
|
|
|
# path to save Stable Diffusion
|
|
sd_path = "./weights"
|
|
util_common.mkdir(sd_path, delete=False, parents=True)
|
|
self.configs.sd_pipe.params.cache_dir = sd_path
|
|
|
|
# path to save noise predictor
|
|
started_ckpt_name = "noise_predictor_sd_turbo_v5.pth"
|
|
started_ckpt_dir = "./weights"
|
|
util_common.mkdir(started_ckpt_dir, delete=False, parents=True)
|
|
started_ckpt_path = os.path.join(started_ckpt_dir, started_ckpt_name)
|
|
if not os.path.exists(started_ckpt_path):
|
|
load_file_from_url(
|
|
url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth",
|
|
model_dir=started_ckpt_dir,
|
|
progress=True,
|
|
file_name=started_ckpt_name,
|
|
)
|
|
self.configs.model_start.ckpt_path = started_ckpt_path
|
|
|
|
self.configs.bs = 1
|
|
self.configs.seed = 12345
|
|
self.configs.basesr.chopping.pch_size = chopping_size
|
|
if chopping_size == 128:
|
|
self.configs.basesr.chopping.extra_bs = 4
|
|
elif chopping_size == 256:
|
|
self.configs.basesr.chopping.extra_bs = 2
|
|
else:
|
|
self.configs.basesr.chopping.extra_bs = 1
|
|
|
|
def predict(
|
|
self,
|
|
in_path: Path = Input(description="Input low-quality image"),
|
|
num_steps: int = Input(
|
|
choices=[1,2,3,4,5], description="Number of sampling steps.", default=1
|
|
),
|
|
chopping_size: int = Input(
|
|
choices=[128, 256, 512], description="Chopping resolution", default=128
|
|
),
|
|
seed: int = Input(
|
|
description="Random seed. Leave blank to randomize the seed.", default=12345
|
|
),
|
|
) -> Path:
|
|
# setting configurations
|
|
self.set_configs(num_steps, chopping_size, seed)
|
|
|
|
sampler = InvSamplerSR(self.configs)
|
|
|
|
out_dir = 'invsr_output'
|
|
if os.path.exists(out_dir):
|
|
shutil.rmtree(out_dir)
|
|
sampler.inference(in_path, out_path=out_dir, bs=1)
|
|
|
|
out = "/tmp/out.png"
|
|
shutil.copy(os.path.join(out_dir, os.listdir(out_dir)[0]), out)
|
|
|
|
return Path(out)
|