mirror of
https://github.com/aljazceru/InvSR.git
synced 2025-12-17 06:14:22 +01:00
first commit
This commit is contained in:
213
basicsr/utils/logger.py
Normal file
213
basicsr/utils/logger.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
from .dist_util import get_dist_info, master_only
|
||||
|
||||
initialized_logger = {}
|
||||
|
||||
|
||||
class AvgTimer():
|
||||
|
||||
def __init__(self, window=200):
|
||||
self.window = window # average window
|
||||
self.current_time = 0
|
||||
self.total_time = 0
|
||||
self.count = 0
|
||||
self.avg_time = 0
|
||||
self.start()
|
||||
|
||||
def start(self):
|
||||
self.start_time = self.tic = time.time()
|
||||
|
||||
def record(self):
|
||||
self.count += 1
|
||||
self.toc = time.time()
|
||||
self.current_time = self.toc - self.tic
|
||||
self.total_time += self.current_time
|
||||
# calculate average time
|
||||
self.avg_time = self.total_time / self.count
|
||||
|
||||
# reset
|
||||
if self.count > self.window:
|
||||
self.count = 0
|
||||
self.total_time = 0
|
||||
|
||||
self.tic = time.time()
|
||||
|
||||
def get_current_time(self):
|
||||
return self.current_time
|
||||
|
||||
def get_avg_time(self):
|
||||
return self.avg_time
|
||||
|
||||
|
||||
class MessageLogger():
|
||||
"""Message logger for printing.
|
||||
|
||||
Args:
|
||||
opt (dict): Config. It contains the following keys:
|
||||
name (str): Exp name.
|
||||
logger (dict): Contains 'print_freq' (str) for logger interval.
|
||||
train (dict): Contains 'total_iter' (int) for total iters.
|
||||
use_tb_logger (bool): Use tensorboard logger.
|
||||
start_iter (int): Start iter. Default: 1.
|
||||
tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, opt, start_iter=1, tb_logger=None):
|
||||
self.exp_name = opt['name']
|
||||
self.interval = opt['logger']['print_freq']
|
||||
self.start_iter = start_iter
|
||||
self.max_iters = opt['train']['total_iter']
|
||||
self.use_tb_logger = opt['logger']['use_tb_logger']
|
||||
self.tb_logger = tb_logger
|
||||
self.start_time = time.time()
|
||||
self.logger = get_root_logger()
|
||||
|
||||
def reset_start_time(self):
|
||||
self.start_time = time.time()
|
||||
|
||||
@master_only
|
||||
def __call__(self, log_vars):
|
||||
"""Format logging message.
|
||||
|
||||
Args:
|
||||
log_vars (dict): It contains the following keys:
|
||||
epoch (int): Epoch number.
|
||||
iter (int): Current iter.
|
||||
lrs (list): List for learning rates.
|
||||
|
||||
time (float): Iter time.
|
||||
data_time (float): Data time for each iter.
|
||||
"""
|
||||
# epoch, iter, learning rates
|
||||
epoch = log_vars.pop('epoch')
|
||||
current_iter = log_vars.pop('iter')
|
||||
lrs = log_vars.pop('lrs')
|
||||
|
||||
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
|
||||
for v in lrs:
|
||||
message += f'{v:.3e},'
|
||||
message += ')] '
|
||||
|
||||
# time and estimated time
|
||||
if 'time' in log_vars.keys():
|
||||
iter_time = log_vars.pop('time')
|
||||
data_time = log_vars.pop('data_time')
|
||||
|
||||
total_time = time.time() - self.start_time
|
||||
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
|
||||
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
||||
message += f'[eta: {eta_str}, '
|
||||
message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
|
||||
|
||||
# other items, especially losses
|
||||
for k, v in log_vars.items():
|
||||
message += f'{k}: {v:.4e} '
|
||||
# tensorboard logger
|
||||
if self.use_tb_logger and 'debug' not in self.exp_name:
|
||||
if k.startswith('l_'):
|
||||
self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
|
||||
else:
|
||||
self.tb_logger.add_scalar(k, v, current_iter)
|
||||
self.logger.info(message)
|
||||
|
||||
|
||||
@master_only
|
||||
def init_tb_logger(log_dir):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
tb_logger = SummaryWriter(log_dir=log_dir)
|
||||
return tb_logger
|
||||
|
||||
|
||||
@master_only
|
||||
def init_wandb_logger(opt):
|
||||
"""We now only use wandb to sync tensorboard log."""
|
||||
import wandb
|
||||
logger = get_root_logger()
|
||||
|
||||
project = opt['logger']['wandb']['project']
|
||||
resume_id = opt['logger']['wandb'].get('resume_id')
|
||||
if resume_id:
|
||||
wandb_id = resume_id
|
||||
resume = 'allow'
|
||||
logger.warning(f'Resume wandb logger with id={wandb_id}.')
|
||||
else:
|
||||
wandb_id = wandb.util.generate_id()
|
||||
resume = 'never'
|
||||
|
||||
wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
|
||||
|
||||
logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
|
||||
|
||||
|
||||
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
|
||||
"""Get the root logger.
|
||||
|
||||
The logger will be initialized if it has not been initialized. By default a
|
||||
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
||||
also be added.
|
||||
|
||||
Args:
|
||||
logger_name (str): root logger name. Default: 'basicsr'.
|
||||
log_file (str | None): The log filename. If specified, a FileHandler
|
||||
will be added to the root logger.
|
||||
log_level (int): The root logger level. Note that only the process of
|
||||
rank 0 is affected, while other processes will set the level to
|
||||
"Error" and be silent most of the time.
|
||||
|
||||
Returns:
|
||||
logging.Logger: The root logger.
|
||||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
# if the logger has been initialized, just return it
|
||||
if logger_name in initialized_logger:
|
||||
return logger
|
||||
|
||||
format_str = '%(asctime)s %(levelname)s: %(message)s'
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter(format_str))
|
||||
logger.addHandler(stream_handler)
|
||||
logger.propagate = False
|
||||
rank, _ = get_dist_info()
|
||||
if rank != 0:
|
||||
logger.setLevel('ERROR')
|
||||
elif log_file is not None:
|
||||
logger.setLevel(log_level)
|
||||
# add file handler
|
||||
file_handler = logging.FileHandler(log_file, 'w')
|
||||
file_handler.setFormatter(logging.Formatter(format_str))
|
||||
file_handler.setLevel(log_level)
|
||||
logger.addHandler(file_handler)
|
||||
initialized_logger[logger_name] = True
|
||||
return logger
|
||||
|
||||
|
||||
def get_env_info():
|
||||
"""Get environment information.
|
||||
|
||||
Currently, only log the software version.
|
||||
"""
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from basicsr.version import __version__
|
||||
msg = r"""
|
||||
____ _ _____ ____
|
||||
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
|
||||
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
|
||||
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
|
||||
/_____/ \__,_//____//_/ \___//____//_/ |_|
|
||||
______ __ __ __ __
|
||||
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
|
||||
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
|
||||
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
|
||||
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
|
||||
"""
|
||||
msg += ('\nVersion Information: '
|
||||
f'\n\tBasicSR: {__version__}'
|
||||
f'\n\tPyTorch: {torch.__version__}'
|
||||
f'\n\tTorchVision: {torchvision.__version__}')
|
||||
return msg
|
||||
Reference in New Issue
Block a user