From 8dec48254ed82d023c3b420629560e3381885954 Mon Sep 17 00:00:00 2001 From: zsyOAOA Date: Sat, 14 Dec 2024 15:08:58 +0800 Subject: [PATCH] add func of get_torch_dtype --- sampler_invsr.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sampler_invsr.py b/sampler_invsr.py index 3483779..4b7c23e 100644 --- a/sampler_invsr.py +++ b/sampler_invsr.py @@ -10,8 +10,6 @@ from pathlib import Path from loguru import logger from omegaconf import OmegaConf -from trainer import get_torch_dtype - from utils import util_net from utils import util_image from utils import util_common @@ -272,6 +270,16 @@ class InvSamplerSR(BaseSampler): self.write_log(f"Processing done, enjoy the results in {str(out_path)}") +def get_torch_dtype(torch_dtype: str): + if torch_dtype == 'torch.float16': + return torch.float16 + elif torch_dtype == 'torch.bfloat16': + return torch.bfloat16 + elif torch_dtype == 'torch.float32': + return torch.float32 + else: + raise ValueError(f'Unexpected torch dtype:{torch_dtype}') + if __name__ == '__main__': pass