Add the ability to auto restart the last run when started after a crash (#191)

* Added autoresume after crash functionality to the trackers

* Updated documentation

* Clarified what goes in the autorestart object

* Fixed style issues

* Unraveled conditional block

* Changed to using a helper function to get the step count
Aidan Dempster
2022-07-08 16:35:40 -04:00
committed by GitHub
parent d7bc5fbedd
commit a71f693a26
6 changed files with 104 additions and 18 deletions


@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
+| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |

Any parameter from the `Decoder` constructor can also be given here.
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
-| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
+| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
+| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -106,6 +108,13 @@ Tracking is split up into three sections:
**Logging:**
+All loggers have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | The type of logger class to use. |
+| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
+| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
If using `console` there is no further configuration beyond setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
@@ -119,10 +128,15 @@ If using `wandb`
| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
-| `wandb_resume` | No | `False` | Whether to resume an old run. |
**Loading:**
+All loaders have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | The type of loader class to use. |
+| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
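Putting the new keys together, a tracker section with crash recovery enabled might look like the sketch below. Only `log_type`, `resume`, `auto_resume`, `load_from`, `only_auto_resume`, and the wandb keys come from the tables above; the `log`/`load` section names and all values are illustrative assumptions, not part of this commit.

```python
# Hedged sketch of a tracker config with auto-resume enabled.
# The "log"/"load" section names are assumed; the individual keys match the
# tables above and the wandb logger options.
tracker_config = {
    "log": {
        "log_type": "wandb",
        "wandb_entity": "my-entity",    # placeholder value
        "wandb_project": "my-project",  # placeholder value
        "auto_resume": True,            # reattach to the previous run after a crash
    },
    "load": {
        "load_from": "local",
        "only_auto_resume": True,       # only load a checkpoint when auto-resuming
        # loader-specific keys (e.g. the checkpoint location) are omitted here
    },
}
```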


@@ -20,7 +20,7 @@
    },
    "data": {
        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
-        "embeddings_url": "s3://bucket/embeddings/path/",
+        "img_embeddings_url": "s3://bucket/img_embeddings/path/",
        "num_workers": 4,
        "batch_size": 64,
        "start_shard": 0,


@@ -1,5 +1,6 @@
import urllib.request
import os
+import json
from pathlib import Path
import shutil
from itertools import zip_longest
@@ -37,14 +38,17 @@ class BaseLogger:
        data_path (str): A file path for storing temporary data.
        verbose (bool): Whether or not to always print logs to the console.
    """
-    def __init__(self, data_path: str, verbose: bool = False, **kwargs):
+    def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
        self.data_path = Path(data_path)
+        self.resume = resume
+        self.auto_resume = auto_resume
        self.verbose = verbose

    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        """
        Initializes the logger.
        Errors if the logger is invalid.
+        full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
        """
        raise NotImplementedError
@@ -60,6 +64,14 @@ class BaseLogger:
    def log_error(self, error_string, **kwargs) -> None:
        raise NotImplementedError

+    def get_resume_data(self, **kwargs) -> dict:
+        """
+        Returns the data that, along with { "resume": True }, will be used to set logger attributes when resuming training.
+        It is assumed that after init is called this data will be complete.
+        If the logger does not have any resume functionality, it should return an empty dict.
+        """
+        raise NotImplementedError

class ConsoleLogger(BaseLogger):
    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        print("Logging to console")
@@ -76,6 +88,9 @@ class ConsoleLogger(BaseLogger):
    def log_error(self, error_string, **kwargs) -> None:
        print(error_string)

+    def get_resume_data(self, **kwargs) -> dict:
+        return {}

class WandbLogger(BaseLogger):
    """
    Logs to a wandb run.
@@ -85,7 +100,6 @@ class WandbLogger(BaseLogger):
        wandb_project (str): The wandb project to log to.
        wandb_run_id (str): The wandb run id to resume.
        wandb_run_name (str): The wandb run name to use.
-        wandb_resume (bool): Whether to resume a wandb run.
    """
    def __init__(self,
        data_path: str,
@@ -93,7 +107,6 @@ class WandbLogger(BaseLogger):
        wandb_project: str,
        wandb_run_id: Optional[str] = None,
        wandb_run_name: Optional[str] = None,
-        wandb_resume: bool = False,
        **kwargs
    ):
        super().__init__(data_path, **kwargs)
@@ -101,7 +114,6 @@ class WandbLogger(BaseLogger):
        self.project = wandb_project
        self.run_id = wandb_run_id
        self.run_name = wandb_run_name
-        self.resume = wandb_resume

    def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
        assert self.entity is not None, "wandb_entity must be specified for wandb logger"
@@ -149,6 +161,14 @@ class WandbLogger(BaseLogger):
        print(error_string)
        self.wandb.log({"error": error_string, **kwargs}, step=step)

+    def get_resume_data(self, **kwargs) -> dict:
+        # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
+        return {
+            "entity": self.entity,
+            "project": self.project,
+            "run_id": self.wandb.run.id
+        }

logger_type_map = {
    'console': ConsoleLogger,
    'wandb': WandbLogger,
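The resume dict above mirrors what wandb itself needs to reattach to a run. The sketch below is not part of this commit; it only illustrates how those three values would typically be fed back into `wandb.init` after `_load_auto_resume` has copied them onto the logger. The helper name and `resume="must"` are assumptions about the desired behaviour, not code from this repository.

```python
import wandb

def reattach_run(entity: str, project: str, run_id: str):
    """Illustrative helper (not from this commit): reattach to a crashed
    wandb run using the values stored in auto_resume.json."""
    return wandb.init(
        entity=entity,
        project=project,
        id=run_id,
        resume="must",  # assumption: fail loudly if the run id no longer exists
    )
```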
@@ -168,8 +188,9 @@ class BaseLoader:
    Parameters:
        data_path (str): A file path for storing temporary data.
    """
-    def __init__(self, data_path: str, **kwargs):
+    def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
        self.data_path = Path(data_path)
+        self.only_auto_resume = only_auto_resume

    def init(self, logger: BaseLogger, **kwargs) -> None:
        raise NotImplementedError
@@ -304,6 +325,10 @@ class LocalSaver(BaseSaver):
    def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
        # Copy the file to save_path
        save_path_file_name = Path(save_path).name
+        # Make sure parent directory exists
+        save_path_parent = Path(save_path).parent
+        if not save_path_parent.exists():
+            save_path_parent.mkdir(parents=True)
        print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
        shutil.copy(local_path, save_path)
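As a side note (not what the commit does), pathlib can collapse the existence check and directory creation into a single call:

```python
from pathlib import Path

save_path = "checkpoints/latest.pth"  # hypothetical example path
# Create the parent directory if it is missing; no-op when it already exists.
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
```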
@@ -385,11 +410,7 @@ class Tracker:
    def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
        self.data_path = Path(data_path)
        if not dummy_mode:
-            if overwrite_data_path:
-                if self.data_path.exists():
-                    shutil.rmtree(self.data_path)
-                self.data_path.mkdir(parents=True)
-            else:
+            if not overwrite_data_path:
                assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
            if not self.data_path.exists():
                self.data_path.mkdir(parents=True)
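With the conditional unraveled, `overwrite_data_path=True` no longer deletes an existing data directory; it only skips the "already exists" assertion and reuses the folder. My reading (not stated in the commit) is that this is what lets `auto_resume.json` survive a crash and be found on the next start. A minimal usage sketch, with a placeholder path:

```python
# Illustrative only: `Tracker` is the class defined in this file, and the
# path value is a placeholder. Reusing the same tracker-data folder across
# restarts keeps auto_resume.json from the crashed run available.
tracker = Tracker(data_path=".tracker-data", overwrite_data_path=True)
```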
@@ -398,7 +419,46 @@ class Tracker:
        self.savers: List[BaseSaver] = []
        self.dummy_mode = dummy_mode

+    def _load_auto_resume(self) -> bool:
+        # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
+        if not self.auto_resume_path.exists():
+            if self.logger.auto_resume:
+                print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
+            return False
+        # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
+        if not self.logger.auto_resume:
+            print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
+            self.auto_resume_path.unlink()
+            return False
+        # Otherwise we read the json into a dictionary which will override parts of logger.__dict__
+        with open(self.auto_resume_path, 'r') as f:
+            auto_resume_dict = json.load(f)
+        # Check if the logger is of the same type as the autoresume save
+        if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
+            raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
+        # Then we are ready to override the logger with the autoresume save
+        self.logger.__dict__["resume"] = True
+        print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
+        self.logger.__dict__.update(auto_resume_dict)
+        return True

+    def _save_auto_resume(self):
+        # Gets the autoresume dict from the logger, adds "logger_type" to it, then saves it to the auto_resume file
+        auto_resume_dict = self.logger.get_resume_data()
+        auto_resume_dict['logger_type'] = self.logger.__class__.__name__
+        with open(self.auto_resume_path, 'w') as f:
+            json.dump(auto_resume_dict, f)

    def init(self, full_config: BaseModel, extra_config: dict):
+        self.auto_resume_path = self.data_path / 'auto_resume.json'
+        # Check for resuming the run
+        self.did_auto_resume = self._load_auto_resume()
+        if self.did_auto_resume:
+            print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
+            print(f"New logger config: {self.logger.__dict__}")
        assert self.logger is not None, '`logger` must be set before `init` is called'
        if self.dummy_mode:
            # The only thing we need is a loader
@@ -406,12 +466,17 @@ class Tracker:
                self.loader.init(self.logger)
            return
        assert len(self.savers) > 0, '`savers` must be set before `init` is called'

        self.logger.init(full_config, extra_config)
        if self.loader is not None:
            self.loader.init(self.logger)
        for saver in self.savers:
            saver.init(self.logger)

+        if self.logger.auto_resume:
+            # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
+            self._save_auto_resume()

    def add_logger(self, logger: BaseLogger):
        self.logger = logger
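For a `WandbLogger`, the file written by `_save_auto_resume` is simply the dict from `get_resume_data` plus the logger type. The snippet below shows what `auto_resume.json` would hold in that case; the entity, project, and run-id values are placeholders, not real data from this commit.

```python
import json

# Placeholder values; a real file holds whatever WandbLogger.get_resume_data()
# returned on the previous run, plus the logger type added by _save_auto_resume.
auto_resume_contents = {
    "entity": "my-entity",
    "project": "my-project",
    "run_id": "1a2b3c4d",
    "logger_type": "WandbLogger",
}

# On the next start, _load_auto_resume reads this dict, sets resume = True on
# the logger, and copies entity/project/run_id onto it via __dict__.update().
print(json.dumps(auto_resume_contents, indent=2))
```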
@@ -503,11 +568,16 @@ class Tracker:
            self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
            print(f'Error saving checkpoint: {e}')

+    @property
+    def can_recall(self):
+        # Defines whether a recall can be performed.
+        return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)

    def recall(self):
-        if self.loader is not None:
+        if self.can_recall:
            return self.loader.recall()
        else:
-            raise ValueError('No loader specified')
+            raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')


@@ -47,6 +47,8 @@ class TrainSplitConfig(BaseModel):
class TrackerLogConfig(BaseModel):
    log_type: str = 'console'
+    resume: bool = False  # For logs that are saved to unique locations, resume a previous run
+    auto_resume: bool = False  # If the process crashes and restarts, resume from the run that crashed
    verbose: bool = False

    class Config:
@@ -59,6 +61,7 @@ class TrackerLogConfig(BaseModel):
class TrackerLoadConfig(BaseModel):
    load_from: Optional[str] = None
+    only_auto_resume: bool = False  # Only attempt to load if the logger is auto-resuming

    class Config:
        extra = "allow"


@@ -509,7 +509,6 @@ class DecoderTrainer(nn.Module):
        self.register_buffer('steps', torch.tensor([0] * self.num_unets))

        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
-        schedulers = list(self.accelerator.prepare(*schedulers))
        self.decoder = decoder


@@ -289,9 +289,9 @@ def train(
    sample = 0
    samples_seen = 0
    val_sample = 0
-    step = lambda: int(trainer.step.item())
+    step = lambda: int(trainer.num_steps_taken(unet_number=1))

-    if tracker.loader is not None:
+    if tracker.can_recall:
        start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
        if next_task == 'train':
            sample = recalled_sample
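The training script now asks the trainer for the step count instead of reading `trainer.step` directly. `num_steps_taken` itself is not part of this diff; the sketch below is only a guess at its behaviour, based on the per-unet `steps` buffer registered in `DecoderTrainer.__init__` above, with `unet_number` assumed to be 1-based as at the call site.

```python
import torch

class DecoderTrainerSketch:
    """Hypothetical sketch, not from this commit: just enough state to show how
    a per-unet step counter could back num_steps_taken()."""
    def __init__(self, num_unets: int = 2):
        self.steps = torch.tensor([0] * num_unets)  # mirrors the registered buffer

    def num_steps_taken(self, unet_number: int = 1) -> int:
        # unet_number is 1-based at the call site; the buffer is 0-indexed
        return int(self.steps[unet_number - 1].item())
```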