mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-16 22:04:20 +01:00
add func of get_torch_dtype
This commit is contained in:
@@ -10,8 +10,6 @@ from pathlib import Path
|
||||
from loguru import logger
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from trainer import get_torch_dtype
|
||||
|
||||
from utils import util_net
|
||||
from utils import util_image
|
||||
from utils import util_common
|
||||
@@ -272,6 +270,16 @@ class InvSamplerSR(BaseSampler):
|
||||
|
||||
self.write_log(f"Processing done, enjoy the results in {str(out_path)}")
|
||||
|
||||
def get_torch_dtype(torch_dtype: str):
|
||||
if torch_dtype == 'torch.float16':
|
||||
return torch.float16
|
||||
elif torch_dtype == 'torch.bfloat16':
|
||||
return torch.bfloat16
|
||||
elif torch_dtype == 'torch.float32':
|
||||
return torch.float32
|
||||
else:
|
||||
raise ValueError(f'Unexpected torch dtype:{torch_dtype}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user