add func of get_torch_dtype

This commit is contained in:
zsyOAOA
2024-12-14 15:08:58 +08:00
parent 9ca21e1b8b
commit 8dec48254e

View File

@@ -10,8 +10,6 @@ from pathlib import Path
from loguru import logger from loguru import logger
from omegaconf import OmegaConf from omegaconf import OmegaConf
from trainer import get_torch_dtype
from utils import util_net from utils import util_net
from utils import util_image from utils import util_image
from utils import util_common from utils import util_common
@@ -272,6 +270,16 @@ class InvSamplerSR(BaseSampler):
self.write_log(f"Processing done, enjoy the results in {str(out_path)}") self.write_log(f"Processing done, enjoy the results in {str(out_path)}")
def get_torch_dtype(torch_dtype: str):
if torch_dtype == 'torch.float16':
return torch.float16
elif torch_dtype == 'torch.bfloat16':
return torch.bfloat16
elif torch_dtype == 'torch.float32':
return torch.float32
else:
raise ValueError(f'Unexpected torch dtype:{torch_dtype}')
if __name__ == '__main__': if __name__ == '__main__':
pass pass