mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
* Overhauled the tracker system. Separated the logging and saving capabilities. Changed creation to be consistent, with initialization behavior defined by a class initializer instead of in the training script. Added class separation between different types of loaders and savers to make the system more verbose. * Changed the saver system to only save the checkpoint once. * Added better error handling for saving checkpoints. * Fixed an error where wandb would error when passed arbitrary kwargs. * Fixed variable naming issues for the improved saver. Added more logging during long pauses. * Fixed which methods need to be dummy methods that return immediately. Added the ability to set whether to find unused parameters. * Added more logging for when a wandb loader fails.
36 lines
661 B
Python
36 lines
661 B
Python
import time
|
|
import importlib
|
|
|
|
# helper functions
|
|
|
|
def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' count as existing)."""
    if val is None:
        return False
    return True
|
|
|
|
# time helpers
|
|
|
|
class Timer:
    """Stopwatch reporting the seconds elapsed since construction or the last reset.

    Readings are taken from ``time.monotonic()`` rather than ``time.time()``:
    the wall clock can jump backwards or forwards (NTP sync, manual changes),
    which would yield negative or inflated durations.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the timer from the current instant."""
        # Monotonic reading: only meaningful as a difference, not as a Unix timestamp.
        self.last_time = time.monotonic()

    def elapsed(self):
        """Return the number of seconds (float) since construction or the last ``reset()``."""
        return time.monotonic() - self.last_time
|
|
|
|
# print helpers
|
|
|
|
def print_ribbon(s, symbol = '=', repeat = 40):
    """Build (not print) a banner string: `s` flanked on both sides by `symbol` repeated `repeat` times.

    A single space separates each flank from `s`.
    """
    border = symbol * repeat
    return '{0} {1} {0}'.format(border, s)
|
|
|
|
# import helpers
|
|
|
|
def import_or_print_error(pkg_name, err_str = None):
    """Import and return the module named `pkg_name`, or terminate the process.

    Args:
        pkg_name: Dotted module name handed to ``importlib.import_module``.
        err_str: Optional message printed before exiting when the module
            cannot be found. Nothing is printed when it is None.

    Returns:
        The imported module object.

    Raises:
        SystemExit: With exit status 1 when the module is not installed.
    """
    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError:
        if err_str is not None:
            print(err_str)
        # Raise SystemExit directly instead of calling the `exit()` helper:
        # that helper is injected by the `site` module and may be missing
        # (python -S, frozen apps), and it reported success (status 0).
        raise SystemExit(1)
|