Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
Add the ability to auto restart the last run when started after a crash (#191)
* Added autoresume-after-crash functionality to the trackers
* Updated documentation
* Clarified what goes in the autorestart object
* Fixed style issues
* Unraveled conditional block
* Changed to using helper function to get step count
@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
 | `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
 | `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
 | `learned_variance` | No | `True` | Whether to learn the variance. |
+| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
 
 Any parameter from the `Decoder` constructor can also be given here.
 
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
 | `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
-| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
+| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
+| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
 | `num_workers` | No | `4` | The number of workers used in the dataloader. |
 | `batch_size` | No | `64` | The batch size. |
 | `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -106,6 +108,13 @@ Tracking is split up into three sections:
 
 **Logging:**
 
+All loggers have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | The type of logger class to use. |
+| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
+| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
+
 If using `console` there is no further configuration than setting `log_type` to `console`.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
@@ -119,10 +128,15 @@ If using `wandb`
 | `wandb_project` | Yes | N/A | The wandb project to save the run to. |
 | `wandb_run_name` | No | `None` | The wandb run name. |
 | `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
-| `wandb_resume` | No | `False` | Whether to resume an old run. |
 
 **Loading:**
 
+All loaders have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | The type of loader class to use. |
+| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
+
 If using `local`
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
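For orientation, a tracker section that opts into crash recovery might look roughly like the following. This is an illustrative sketch assembled from the option tables above, not a snippet from this commit; the `tracker`/`log`/`load` nesting and the entity, project, and path values are assumed placeholders.

```json
{
    "tracker": {
        "log": {
            "log_type": "wandb",
            "wandb_entity": "my-entity",
            "wandb_project": "dalle2-decoder",
            "auto_resume": true
        },
        "load": {
            "load_from": "local",
            "only_auto_resume": true
        }
    }
}
```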
@@ -20,7 +20,7 @@
     },
     "data": {
         "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
-        "embeddings_url": "s3://bucket/embeddings/path/",
+        "img_embeddings_url": "s3://bucket/img_embeddings/path/",
         "num_workers": 4,
         "batch_size": 64,
         "start_shard": 0,
@@ -1,5 +1,6 @@
 import urllib.request
 import os
+import json
 from pathlib import Path
 import shutil
 from itertools import zip_longest
@@ -37,14 +38,17 @@ class BaseLogger:
         data_path (str): A file path for storing temporary data.
         verbose (bool): Whether or not to always print logs to the console.
     """
-    def __init__(self, data_path: str, verbose: bool = False, **kwargs):
+    def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
         self.data_path = Path(data_path)
+        self.resume = resume
+        self.auto_resume = auto_resume
         self.verbose = verbose
 
     def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
         """
         Initializes the logger.
         Errors if the logger is invalid.
+        full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
         """
         raise NotImplementedError
 
@@ -60,6 +64,14 @@ class BaseLogger:
     def log_error(self, error_string, **kwargs) -> None:
         raise NotImplementedError
 
+    def get_resume_data(self, **kwargs) -> dict:
+        """
+        Sets tracker attributes that along with { "resume": True } will be used to resume training.
+        It is assumed that after init is called this data will be complete.
+        If the logger does not have any resume functionality, it should return an empty dict.
+        """
+        raise NotImplementedError
+
 class ConsoleLogger(BaseLogger):
     def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
         print("Logging to console")
@@ -76,6 +88,9 @@ class ConsoleLogger(BaseLogger):
     def log_error(self, error_string, **kwargs) -> None:
         print(error_string)
 
+    def get_resume_data(self, **kwargs) -> dict:
+        return {}
+
 class WandbLogger(BaseLogger):
     """
     Logs to a wandb run.
@@ -85,7 +100,6 @@ class WandbLogger(BaseLogger):
         wandb_project (str): The wandb project to log to.
         wandb_run_id (str): The wandb run id to resume.
         wandb_run_name (str): The wandb run name to use.
-        wandb_resume (bool): Whether to resume a wandb run.
     """
     def __init__(self,
         data_path: str,
@@ -93,7 +107,6 @@ class WandbLogger(BaseLogger):
         wandb_project: str,
         wandb_run_id: Optional[str] = None,
         wandb_run_name: Optional[str] = None,
-        wandb_resume: bool = False,
         **kwargs
     ):
         super().__init__(data_path, **kwargs)
@@ -101,7 +114,6 @@ class WandbLogger(BaseLogger):
         self.project = wandb_project
         self.run_id = wandb_run_id
         self.run_name = wandb_run_name
-        self.resume = wandb_resume
 
     def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
         assert self.entity is not None, "wandb_entity must be specified for wandb logger"
@@ -149,6 +161,14 @@ class WandbLogger(BaseLogger):
         print(error_string)
         self.wandb.log({"error": error_string, **kwargs}, step=step)
 
+    def get_resume_data(self, **kwargs) -> dict:
+        # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
+        return {
+            "entity": self.entity,
+            "project": self.project,
+            "run_id": self.wandb.run.id
+        }
+
 logger_type_map = {
     'console': ConsoleLogger,
     'wandb': WandbLogger,
@@ -168,8 +188,9 @@ class BaseLoader:
     Parameters:
         data_path (str): A file path for storing temporary data.
     """
-    def __init__(self, data_path: str, **kwargs):
+    def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
         self.data_path = Path(data_path)
+        self.only_auto_resume = only_auto_resume
 
     def init(self, logger: BaseLogger, **kwargs) -> None:
         raise NotImplementedError
@@ -304,6 +325,10 @@ class LocalSaver(BaseSaver):
     def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
         # Copy the file to save_path
         save_path_file_name = Path(save_path).name
+        # Make sure parent directory exists
+        save_path_parent = Path(save_path).parent
+        if not save_path_parent.exists():
+            save_path_parent.mkdir(parents=True)
         print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
         shutil.copy(local_path, save_path)
 
@@ -385,11 +410,7 @@ class Tracker:
     def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
         self.data_path = Path(data_path)
         if not dummy_mode:
-            if overwrite_data_path:
-                if self.data_path.exists():
-                    shutil.rmtree(self.data_path)
-                self.data_path.mkdir(parents=True)
-            else:
+            if not overwrite_data_path:
                 assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
             if not self.data_path.exists():
                 self.data_path.mkdir(parents=True)
@@ -398,7 +419,46 @@ class Tracker:
         self.savers: List[BaseSaver]= []
         self.dummy_mode = dummy_mode
 
+    def _load_auto_resume(self) -> bool:
+        # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
+        if not self.auto_resume_path.exists():
+            if self.logger.auto_resume:
+                print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
+            return False
+
+        # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
+        if not self.logger.auto_resume:
+            print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
+            self.auto_resume_path.unlink()
+            return False
+
+        # Otherwise we read the json into a dictionary that will override parts of logger.__dict__
+        with open(self.auto_resume_path, 'r') as f:
+            auto_resume_dict = json.load(f)
+        # Check if the logger is of the same type as the autoresume save
+        if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
+            raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
+        # Then we are ready to override the logger with the autoresume save
+        self.logger.__dict__["resume"] = True
+        print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
+        self.logger.__dict__.update(auto_resume_dict)
+        return True
+
+    def _save_auto_resume(self):
+        # Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
+        auto_resume_dict = self.logger.get_resume_data()
+        auto_resume_dict['logger_type'] = self.logger.__class__.__name__
+        with open(self.auto_resume_path, 'w') as f:
+            json.dump(auto_resume_dict, f)
+
     def init(self, full_config: BaseModel, extra_config: dict):
+        self.auto_resume_path = self.data_path / 'auto_resume.json'
+        # Check for resuming the run
+        self.did_auto_resume = self._load_auto_resume()
+        if self.did_auto_resume:
+            print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
+            print(f"New logger config: {self.logger.__dict__}")
+
         assert self.logger is not None, '`logger` must be set before `init` is called'
         if self.dummy_mode:
             # The only thing we need is a loader
@@ -406,12 +466,17 @@ class Tracker:
             self.loader.init(self.logger)
             return
         assert len(self.savers) > 0, '`savers` must be set before `init` is called'
 
         self.logger.init(full_config, extra_config)
         if self.loader is not None:
             self.loader.init(self.logger)
         for saver in self.savers:
             saver.init(self.logger)
 
+        if self.logger.auto_resume:
+            # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
+            self._save_auto_resume()
+
     def add_logger(self, logger: BaseLogger):
         self.logger = logger
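To make the flow above concrete: with the `WandbLogger`, `_save_auto_resume` writes whatever `get_resume_data()` returns plus the injected `logger_type` key, so the `auto_resume.json` sitting in the tracker data path would look roughly like this (the values are placeholders, not taken from this commit):

```json
{
    "entity": "my-entity",
    "project": "dalle2-decoder",
    "run_id": "1a2b3c4d",
    "logger_type": "WandbLogger"
}
```

On the next start, `_load_auto_resume` merges this dict into `logger.__dict__` and forces `resume` to `True`, so the wandb run with that id is picked back up.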
@@ -503,11 +568,16 @@ class Tracker:
             self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
             print(f'Error saving checkpoint: {e}')
 
+    @property
+    def can_recall(self):
+        # Defines whether a recall can be performed.
+        return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
+
     def recall(self):
-        if self.loader is not None:
+        if self.can_recall:
             return self.loader.recall()
         else:
-            raise ValueError('No loader specified')
+            raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
 
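Downstream code is expected to gate recalls on this property instead of checking `tracker.loader` directly (the `train_decoder.py` hunk at the end of this commit does exactly that). A minimal sketch of the pattern, with `state` as a placeholder for whatever the configured loader returns:

```python
if tracker.can_recall:
    # recall() delegates to the loader; with only_auto_resume set it is only
    # reachable when the run was actually auto-resumed after a crash.
    state = tracker.recall()
else:
    print("No checkpoint to recall; starting a fresh run")
```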
@@ -47,6 +47,8 @@ class TrainSplitConfig(BaseModel):
 
 class TrackerLogConfig(BaseModel):
     log_type: str = 'console'
+    resume: bool = False  # For logs that are saved to unique locations, resume a previous run
+    auto_resume: bool = False  # If the process crashes and restarts, resume from the run that crashed
     verbose: bool = False
 
     class Config:
@@ -59,6 +61,7 @@ class TrackerLogConfig(BaseModel):
 
 class TrackerLoadConfig(BaseModel):
     load_from: Optional[str] = None
+    only_auto_resume: bool = False  # Only attempt to load if the logger is auto-resuming
 
     class Config:
         extra = "allow"
@@ -509,7 +509,6 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))
 
         decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
-        schedulers = list(self.accelerator.prepare(*schedulers))
 
         self.decoder = decoder
 
@@ -289,9 +289,9 @@ def train(
     sample = 0
     samples_seen = 0
     val_sample = 0
-    step = lambda: int(trainer.step.item())
+    step = lambda: int(trainer.num_steps_taken(unet_number=1))
 
-    if tracker.loader is not None:
+    if tracker.can_recall:
         start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
         if next_task == 'train':
             sample = recalled_sample