mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 06:14:22 +01:00
add replicate demo
This commit is contained in:
91
predict.py
Normal file
91
predict.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Prediction interface for Cog ⚙️
|
||||
# https://cog.run/python
|
||||
|
||||
|
||||
import shutil, os
|
||||
from omegaconf import OmegaConf
|
||||
from cog import BasePredictor, Input, Path
|
||||
|
||||
import numpy as np
|
||||
from utils import util_common
|
||||
from sampler_invsr import InvSamplerSR
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
def setup(self) -> None:
|
||||
self.configs = OmegaConf.load("./configs/sample-sd-turbo.yaml")
|
||||
|
||||
def set_configs(self, num_steps=1, chopping_size=128, seed=12345):
|
||||
if num_steps == 1:
|
||||
self.configs.timesteps = [200,]
|
||||
elif num_steps == 2:
|
||||
self.configs.timesteps = [200, 100]
|
||||
elif num_steps == 3:
|
||||
self.configs.timesteps = [200, 100, 50]
|
||||
elif num_steps == 4:
|
||||
self.configs.timesteps = [200, 150, 100, 50]
|
||||
elif num_steps == 5:
|
||||
self.configs.timesteps = [250, 200, 150, 100, 50]
|
||||
else:
|
||||
assert num_steps <= 250
|
||||
self.configs.timesteps = np.linspace(
|
||||
start=250, stop=0, num=num_steps, endpoint=False, dtype=np.int64()
|
||||
).tolist()
|
||||
print(f'Setting timesteps for inference: {self.configs.timesteps}')
|
||||
|
||||
# path to save Stable Diffusion
|
||||
sd_path = "./weights"
|
||||
util_common.mkdir(sd_path, delete=False, parents=True)
|
||||
self.configs.sd_pipe.params.cache_dir = sd_path
|
||||
|
||||
# path to save noise predictor
|
||||
started_ckpt_name = "noise_predictor_sd_turbo_v5.pth"
|
||||
started_ckpt_dir = "./weights"
|
||||
util_common.mkdir(started_ckpt_dir, delete=False, parents=True)
|
||||
started_ckpt_path = os.path.join(started_ckpt_dir, started_ckpt_name)
|
||||
if not os.path.exists(started_ckpt_path):
|
||||
load_file_from_url(
|
||||
url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth",
|
||||
model_dir=started_ckpt_dir,
|
||||
progress=True,
|
||||
file_name=started_ckpt_name,
|
||||
)
|
||||
self.configs.model_start.ckpt_path = started_ckpt_path
|
||||
|
||||
self.configs.bs = 1
|
||||
self.configs.seed = 12345
|
||||
self.configs.basesr.chopping.pch_size = chopping_size
|
||||
if chopping_size == 128:
|
||||
self.configs.basesr.chopping.extra_bs = 4
|
||||
elif chopping_size == 256:
|
||||
self.configs.basesr.chopping.extra_bs = 2
|
||||
else:
|
||||
self.configs.basesr.chopping.extra_bs = 1
|
||||
|
||||
def predict(
|
||||
self,
|
||||
in_path: Path = Input(description="Input low-quality image"),
|
||||
num_steps: int = Input(
|
||||
choices=[1,2,3,4,5], description="Number of sampling steps.", default=1
|
||||
),
|
||||
chopping_size: int = Input(
|
||||
choices=[128, 256, 512], description="Chopping resolution", default=128
|
||||
),
|
||||
seed: int = Input(
|
||||
description="Random seed. Leave blank to randomize the seed.", default=12345
|
||||
),
|
||||
) -> Path:
|
||||
# setting configurations
|
||||
self.set_configs(num_steps, chopping_size, seed)
|
||||
|
||||
sampler = InvSamplerSR(self.configs)
|
||||
|
||||
out_dir = 'invsr_output'
|
||||
if os.path.exists(out_dir):
|
||||
shutil.rmtree(out_dir)
|
||||
sampler.inference(in_path, out_path=out_dir, bs=1)
|
||||
|
||||
out = "/tmp/out.png"
|
||||
shutil.copy(os.path.join(out_dir, os.listdir(out_dir)[0]), out)
|
||||
|
||||
return Path(out)
|
||||
Reference in New Issue
Block a user