Distributed Training of the Decoder (#121)

* Converted decoder trainer to use accelerate

* Fixed issue where metric evaluation would hang on distributed mode

* Implemented functional saving
Loading still fails due to some issue with the optimizer

* Fixed issue with loading decoders

* Fixed issue with tracker config

* Fixed issue with amp
Updated logging to be more logical

* Saving checkpoint now saves position in training as well
Fixed an issue with running out of gpu space due to loading weights into the gpu twice

* Fixed ema for distributed training

* Fixed isue where get_pkg_version was reintroduced

* Changed decoder trainer to upload config as a file

Fixed issue where loading best would error
This commit is contained in:
Aidan Dempster
2022-06-19 12:25:54 -04:00
committed by GitHub
parent e37072a48c
commit 58892135d9
7 changed files with 331 additions and 207 deletions

View File

@@ -17,15 +17,15 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
# load state dict functions
# load file functions
def load_wandb_state_dict(run_path, file_path, **kwargs):
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return torch.load(file_reference.name)
return file_reference.name
def load_local_state_dict(file_path, **kwargs):
return torch.load(file_path)
def load_local_file(file_path, **kwargs):
return file_path
# base class
@@ -55,12 +55,43 @@ class BaseTracker(nn.Module):
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return load_wandb_state_dict(*args, **kwargs)
return torch.load(load_wandb_file(*args, **kwargs))
elif recall_source == 'local':
return load_local_state_dict(*args, **kwargs)
return torch.load(load_local_file(*args, **kwargs))
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
def save_file(self, file_path, **kwargs):
raise NotImplementedError
def recall_file(self, recall_source, *args, **kwargs):
if recall_source == 'wandb':
return load_wandb_file(*args, **kwargs)
elif recall_source == 'local':
return load_local_file(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
# Tracker that no-ops all calls except for recall
class DummyTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def init(self, config, **kwargs):
pass
def log(self, log, **kwargs):
pass
def log_images(self, images, **kwargs):
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
pass
def save_file(self, file_path, **kwargs):
pass
# basic stdout class
@@ -76,6 +107,10 @@ class ConsoleTracker(BaseTracker):
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
def save_file(self, file_path, **kwargs):
# This is a no-op for local file systems since it is already saved locally
pass
# basic wandb class
@@ -107,3 +142,11 @@ class WandbTracker(BaseTracker):
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
def save_file(self, file_path, base_path=None, **kwargs):
"""
Uploads a file from disk to wandb
"""
if base_path is None:
base_path = self.data_path
self.wandb.save(str(file_path), base_path = str(base_path))