mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 06:14:22 +01:00
add func of get_torch_dtype
This commit is contained in:
@@ -10,8 +10,6 @@ from pathlib import Path
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from trainer import get_torch_dtype
|
|
||||||
|
|
||||||
from utils import util_net
|
from utils import util_net
|
||||||
from utils import util_image
|
from utils import util_image
|
||||||
from utils import util_common
|
from utils import util_common
|
||||||
@@ -272,6 +270,16 @@ class InvSamplerSR(BaseSampler):
|
|||||||
|
|
||||||
self.write_log(f"Processing done, enjoy the results in {str(out_path)}")
|
self.write_log(f"Processing done, enjoy the results in {str(out_path)}")
|
||||||
|
|
||||||
|
def get_torch_dtype(torch_dtype: str):
|
||||||
|
if torch_dtype == 'torch.float16':
|
||||||
|
return torch.float16
|
||||||
|
elif torch_dtype == 'torch.bfloat16':
|
||||||
|
return torch.bfloat16
|
||||||
|
elif torch_dtype == 'torch.float32':
|
||||||
|
return torch.float32
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unexpected torch dtype:{torch_dtype}')
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user