mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 01:34:19 +01:00
Add the ability to auto restart the last run when started after a crash (#191)
* Added autoresume after crash functionality to the trackers * Updated documentation * Clarified what goes in the autorestart object * Fixed style issues Unraveled conditional block Chnaged to using helper function to get step count
This commit is contained in:
@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
|
||||
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
|
||||
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
|
||||
| `learned_variance` | No | `True` | Whether to learn the variance. |
|
||||
| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
|
||||
|
||||
Any parameter from the `Decoder` constructor can also be given here.
|
||||
|
||||
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
|
||||
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
|
||||
| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
|
||||
| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
|
||||
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
|
||||
| `batch_size` | No | `64` | The batch size. |
|
||||
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
|
||||
@@ -106,6 +108,13 @@ Tracking is split up into three sections:
|
||||
|
||||
**Logging:**
|
||||
|
||||
All loggers have the following keys:
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `log_type` | Yes | N/A | The type of logger class to use. |
|
||||
| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using maually input parameters. |
|
||||
| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
|
||||
|
||||
If using `console` there is no further configuration than setting `log_type` to `console`.
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
@@ -119,10 +128,15 @@ If using `wandb`
|
||||
| `wandb_project` | Yes | N/A | The wandb project save the run to. |
|
||||
| `wandb_run_name` | No | `None` | The wandb run name. |
|
||||
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
|
||||
| `wandb_resume` | No | `False` | Whether to resume an old run. |
|
||||
|
||||
**Loading:**
|
||||
|
||||
All loaders have the following keys:
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
| `load_from` | Yes | N/A | The type of loader class to use. |
|
||||
| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
|
||||
|
||||
If using `local`
|
||||
| Option | Required | Default | Description |
|
||||
| ------ | -------- | ------- | ----------- |
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
},
|
||||
"data": {
|
||||
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
|
||||
"embeddings_url": "s3://bucket/embeddings/path/",
|
||||
"img_embeddings_url": "s3://bucket/img_embeddings/path/",
|
||||
"num_workers": 4,
|
||||
"batch_size": 64,
|
||||
"start_shard": 0,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import urllib.request
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from itertools import zip_longest
|
||||
@@ -37,14 +38,17 @@ class BaseLogger:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
verbose (bool): Whether of not to always print logs to the console.
|
||||
"""
|
||||
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
|
||||
def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.resume = resume
|
||||
self.auto_resume = auto_resume
|
||||
self.verbose = verbose
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
"""
|
||||
Initializes the logger.
|
||||
Errors if the logger is invalid.
|
||||
full_config is the config file dict while extra_config is anything else from the script that is not defined the config file.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -60,6 +64,14 @@ class BaseLogger:
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
"""
|
||||
Sets tracker attributes that along with { "resume": True } will be used to resume training.
|
||||
It is assumed that after init is called this data will be complete.
|
||||
If the logger does not have any resume functionality, it should return an empty dict.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
class ConsoleLogger(BaseLogger):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
print("Logging to console")
|
||||
@@ -76,6 +88,9 @@ class ConsoleLogger(BaseLogger):
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
print(error_string)
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
return {}
|
||||
|
||||
class WandbLogger(BaseLogger):
|
||||
"""
|
||||
Logs to a wandb run.
|
||||
@@ -85,7 +100,6 @@ class WandbLogger(BaseLogger):
|
||||
wandb_project (str): The wandb project to log to.
|
||||
wandb_run_id (str): The wandb run id to resume.
|
||||
wandb_run_name (str): The wandb run name to use.
|
||||
wandb_resume (bool): Whether to resume a wandb run.
|
||||
"""
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
@@ -93,7 +107,6 @@ class WandbLogger(BaseLogger):
|
||||
wandb_project: str,
|
||||
wandb_run_id: Optional[str] = None,
|
||||
wandb_run_name: Optional[str] = None,
|
||||
wandb_resume: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
@@ -101,7 +114,6 @@ class WandbLogger(BaseLogger):
|
||||
self.project = wandb_project
|
||||
self.run_id = wandb_run_id
|
||||
self.run_name = wandb_run_name
|
||||
self.resume = wandb_resume
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
|
||||
@@ -149,6 +161,14 @@ class WandbLogger(BaseLogger):
|
||||
print(error_string)
|
||||
self.wandb.log({"error": error_string, **kwargs}, step=step)
|
||||
|
||||
def get_resume_data(self, **kwargs) -> dict:
|
||||
# In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
|
||||
return {
|
||||
"entity": self.entity,
|
||||
"project": self.project,
|
||||
"run_id": self.wandb.run.id
|
||||
}
|
||||
|
||||
logger_type_map = {
|
||||
'console': ConsoleLogger,
|
||||
'wandb': WandbLogger,
|
||||
@@ -168,8 +188,9 @@ class BaseLoader:
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
"""
|
||||
def __init__(self, data_path: str, **kwargs):
|
||||
def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.only_auto_resume = only_auto_resume
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
@@ -304,6 +325,10 @@ class LocalSaver(BaseSaver):
|
||||
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
|
||||
# Copy the file to save_path
|
||||
save_path_file_name = Path(save_path).name
|
||||
# Make sure parent directory exists
|
||||
save_path_parent = Path(save_path).parent
|
||||
if not save_path_parent.exists():
|
||||
save_path_parent.mkdir(parents=True)
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
|
||||
shutil.copy(local_path, save_path)
|
||||
|
||||
@@ -385,11 +410,7 @@ class Tracker:
|
||||
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
|
||||
self.data_path = Path(data_path)
|
||||
if not dummy_mode:
|
||||
if overwrite_data_path:
|
||||
if self.data_path.exists():
|
||||
shutil.rmtree(self.data_path)
|
||||
self.data_path.mkdir(parents=True)
|
||||
else:
|
||||
if not overwrite_data_path:
|
||||
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
@@ -398,7 +419,46 @@ class Tracker:
|
||||
self.savers: List[BaseSaver]= []
|
||||
self.dummy_mode = dummy_mode
|
||||
|
||||
def _load_auto_resume(self) -> bool:
|
||||
# If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
|
||||
if not self.auto_resume_path.exists():
|
||||
if self.logger.auto_resume:
|
||||
print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
|
||||
return False
|
||||
|
||||
# Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
|
||||
if not self.logger.auto_resume:
|
||||
print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
|
||||
self.auto_resume_path.unlink()
|
||||
return False
|
||||
|
||||
# Otherwise we read the json into a dictionary will will override parts of logger.__dict__
|
||||
with open(self.auto_resume_path, 'r') as f:
|
||||
auto_resume_dict = json.load(f)
|
||||
# Check if the logger is of the same type as the autoresume save
|
||||
if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
|
||||
raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
|
||||
# Then we are ready to override the logger with the autoresume save
|
||||
self.logger.__dict__["resume"] = True
|
||||
print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
|
||||
self.logger.__dict__.update(auto_resume_dict)
|
||||
return True
|
||||
|
||||
def _save_auto_resume(self):
|
||||
# Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
|
||||
auto_resume_dict = self.logger.get_resume_data()
|
||||
auto_resume_dict['logger_type'] = self.logger.__class__.__name__
|
||||
with open(self.auto_resume_path, 'w') as f:
|
||||
json.dump(auto_resume_dict, f)
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict):
|
||||
self.auto_resume_path = self.data_path / 'auto_resume.json'
|
||||
# Check for resuming the run
|
||||
self.did_auto_resume = self._load_auto_resume()
|
||||
if self.did_auto_resume:
|
||||
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
|
||||
print(f"New logger config: {self.logger.__dict__}")
|
||||
|
||||
assert self.logger is not None, '`logger` must be set before `init` is called'
|
||||
if self.dummy_mode:
|
||||
# The only thing we need is a loader
|
||||
@@ -406,12 +466,17 @@ class Tracker:
|
||||
self.loader.init(self.logger)
|
||||
return
|
||||
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
|
||||
|
||||
self.logger.init(full_config, extra_config)
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
for saver in self.savers:
|
||||
saver.init(self.logger)
|
||||
|
||||
if self.logger.auto_resume:
|
||||
# Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
|
||||
self._save_auto_resume()
|
||||
|
||||
def add_logger(self, logger: BaseLogger):
|
||||
self.logger = logger
|
||||
|
||||
@@ -503,11 +568,16 @@ class Tracker:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
|
||||
@property
|
||||
def can_recall(self):
|
||||
# Defines whether a recall can be performed.
|
||||
return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
|
||||
|
||||
def recall(self):
|
||||
if self.loader is not None:
|
||||
if self.can_recall:
|
||||
return self.loader.recall()
|
||||
else:
|
||||
raise ValueError('No loader specified')
|
||||
raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
|
||||
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@ class TrainSplitConfig(BaseModel):
|
||||
|
||||
class TrackerLogConfig(BaseModel):
|
||||
log_type: str = 'console'
|
||||
resume: bool = False # For logs that are saved to unique locations, resume a previous run
|
||||
auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
|
||||
verbose: bool = False
|
||||
|
||||
class Config:
|
||||
@@ -59,6 +61,7 @@ class TrackerLogConfig(BaseModel):
|
||||
|
||||
class TrackerLoadConfig(BaseModel):
|
||||
load_from: Optional[str] = None
|
||||
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
@@ -509,7 +509,6 @@ class DecoderTrainer(nn.Module):
|
||||
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
|
||||
|
||||
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
|
||||
schedulers = list(self.accelerator.prepare(*schedulers))
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
|
||||
@@ -289,9 +289,9 @@ def train(
|
||||
sample = 0
|
||||
samples_seen = 0
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.step.item())
|
||||
step = lambda: int(trainer.num_steps_taken(unet_number=1))
|
||||
|
||||
if tracker.loader is not None:
|
||||
if tracker.can_recall:
|
||||
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
|
||||
if next_task == 'train':
|
||||
sample = recalled_sample
|
||||
|
||||
Reference in New Issue
Block a user