diff --git a/configs/README.md b/configs/README.md
index 158a9dc..e4fb77d 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
 | `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
 | `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
 | `learned_variance` | No | `True` | Whether to learn the variance. |
+| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
 
 Any parameter from the `Decoder` constructor can also be given here.
 
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
 | `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
-| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
+| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
+| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
 | `num_workers` | No | `4` | The number of workers used in the dataloader. |
 | `batch_size` | No | `64` | The batch size. |
 | `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -106,6 +108,13 @@ Tracking is split up into three sections:
 
 **Logging:**
 
+All loggers have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | The type of logger class to use. |
+| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
+| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
+
 If using `console` there is no further configuration than setting `log_type` to `console`.
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
@@ -119,10 +128,15 @@ If using `wandb`
 | `wandb_project` | Yes | N/A | The wandb project save the run to. |
 | `wandb_run_name` | No | `None` | The wandb run name. |
 | `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
-| `wandb_resume` | No | `False` | Whether to resume an old run. |
 
 **Loading:**
 
+All loaders have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | The type of loader class to use. |
+| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
+
 If using `local`
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
diff --git a/configs/train_decoder_config.example.json b/configs/train_decoder_config.example.json
index 5e20c4a..cebdb02 100644
--- a/configs/train_decoder_config.example.json
+++ b/configs/train_decoder_config.example.json
@@ -20,7 +20,7 @@
     },
     "data": {
         "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
-        "embeddings_url": "s3://bucket/embeddings/path/",
+        "img_embeddings_url": "s3://bucket/img_embeddings/path/",
         "num_workers": 4,
         "batch_size": 64,
         "start_shard": 0,
diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py
index 517f0d7..2d0ba08 100644
--- a/dalle2_pytorch/trackers.py
+++ b/dalle2_pytorch/trackers.py
@@ -1,5 +1,6 @@
 import urllib.request
 import os
+import json
 from pathlib import Path
 import shutil
 from itertools import zip_longest
@@ -37,14 +38,17 @@ class BaseLogger:
         data_path (str): A file path for storing temporary data.
         verbose (bool): Whether of not to always print logs to the console.
     """
-    def __init__(self, data_path: str, verbose: bool = False, **kwargs):
+    def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
         self.data_path = Path(data_path)
+        self.resume = resume
+        self.auto_resume = auto_resume
         self.verbose = verbose
 
     def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
         """
         Initializes the logger.
         Errors if the logger is invalid.
+        full_config is the config file dict while extra_config is anything else from the script that is not defined in the config file.
         """
         raise NotImplementedError
 
@@ -60,6 +64,14 @@ class BaseLogger:
     def log_error(self, error_string, **kwargs) -> None:
         raise NotImplementedError
 
+    def get_resume_data(self, **kwargs) -> dict:
+        """
+        Sets tracker attributes that, along with { "resume": True }, will be used to resume training.
+        It is assumed that after init is called this data will be complete.
+        If the logger does not have any resume functionality, it should return an empty dict.
+        """
+        raise NotImplementedError
+
 class ConsoleLogger(BaseLogger):
     def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
         print("Logging to console")
@@ -76,6 +88,9 @@ class ConsoleLogger(BaseLogger):
     def log_error(self, error_string, **kwargs) -> None:
         print(error_string)
 
+    def get_resume_data(self, **kwargs) -> dict:
+        return {}
+
 class WandbLogger(BaseLogger):
     """
     Logs to a wandb run.
@@ -85,7 +100,6 @@ class WandbLogger(BaseLogger):
         wandb_project (str): The wandb project to log to.
         wandb_run_id (str): The wandb run id to resume.
         wandb_run_name (str): The wandb run name to use.
-        wandb_resume (bool): Whether to resume a wandb run.
     """
     def __init__(self,
                  data_path: str,
@@ -93,7 +107,6 @@ class WandbLogger(BaseLogger):
                  wandb_project: str,
                  wandb_run_id: Optional[str] = None,
                  wandb_run_name: Optional[str] = None,
-                 wandb_resume: bool = False,
                  **kwargs
     ):
         super().__init__(data_path, **kwargs)
@@ -101,7 +114,6 @@ class WandbLogger(BaseLogger):
         self.project = wandb_project
         self.run_id = wandb_run_id
         self.run_name = wandb_run_name
-        self.resume = wandb_resume
 
     def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
         assert self.entity is not None, "wandb_entity must be specified for wandb logger"
@@ -149,6 +161,14 @@ class WandbLogger(BaseLogger):
         print(error_string)
         self.wandb.log({"error": error_string, **kwargs}, step=step)
 
+    def get_resume_data(self, **kwargs) -> dict:
+        # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
+        return {
+            "entity": self.entity,
+            "project": self.project,
+            "run_id": self.wandb.run.id
+        }
+
 logger_type_map = {
     'console': ConsoleLogger,
     'wandb': WandbLogger,
@@ -168,8 +188,9 @@ class BaseLoader:
     Parameters:
         data_path (str): A file path for storing temporary data.
     """
-    def __init__(self, data_path: str, **kwargs):
+    def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
         self.data_path = Path(data_path)
+        self.only_auto_resume = only_auto_resume
 
     def init(self, logger: BaseLogger, **kwargs) -> None:
         raise NotImplementedError
@@ -304,6 +325,10 @@ class LocalSaver(BaseSaver):
     def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
         # Copy the file to save_path
         save_path_file_name = Path(save_path).name
+        # Make sure parent directory exists
+        save_path_parent = Path(save_path).parent
+        if not save_path_parent.exists():
+            save_path_parent.mkdir(parents=True)
         print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
         shutil.copy(local_path, save_path)
 
@@ -385,11 +410,7 @@ class Tracker:
    def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
         self.data_path = Path(data_path)
         if not dummy_mode:
-            if overwrite_data_path:
-                if self.data_path.exists():
-                    shutil.rmtree(self.data_path)
-                self.data_path.mkdir(parents=True)
-            else:
+            if not overwrite_data_path:
                 assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
             if not self.data_path.exists():
                 self.data_path.mkdir(parents=True)
@@ -398,7 +419,46 @@ class Tracker:
         self.savers: List[BaseSaver]= []
         self.dummy_mode = dummy_mode
 
+    def _load_auto_resume(self) -> bool:
+        # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
+        if not self.auto_resume_path.exists():
+            if self.logger.auto_resume:
+                print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
+            return False
+
+        # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
+        if not self.logger.auto_resume:
+            print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
+            self.auto_resume_path.unlink()
+            return False
+
+        # Otherwise we read the json into a dictionary which will override parts of logger.__dict__
+        with open(self.auto_resume_path, 'r') as f:
+            auto_resume_dict = json.load(f)
+        # Check if the logger is of the same type as the autoresume save
+        if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
+            raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
+        # Then we are ready to override the logger with the autoresume save
+        self.logger.__dict__["resume"] = True
+        print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
+        self.logger.__dict__.update(auto_resume_dict)
+        return True
+
+    def _save_auto_resume(self):
+        # Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
+        auto_resume_dict = self.logger.get_resume_data()
+        auto_resume_dict['logger_type'] = self.logger.__class__.__name__
+        with open(self.auto_resume_path, 'w') as f:
+            json.dump(auto_resume_dict, f)
+
     def init(self, full_config: BaseModel, extra_config: dict):
+        self.auto_resume_path = self.data_path / 'auto_resume.json'
+        # Check for resuming the run
+        self.did_auto_resume = self._load_auto_resume()
+        if self.did_auto_resume:
+            print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
+            print(f"New logger config: {self.logger.__dict__}")
+
         assert self.logger is not None, '`logger` must be set before `init` is called'
         if self.dummy_mode:
             # The only thing we need is a loader
@@ -406,12 +466,17 @@ class Tracker:
             self.loader.init(self.logger)
             return
         assert len(self.savers) > 0, '`savers` must be set before `init` is called'
+
         self.logger.init(full_config, extra_config)
         if self.loader is not None:
             self.loader.init(self.logger)
         for saver in self.savers:
             saver.init(self.logger)
 
+        if self.logger.auto_resume:
+            # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
+            self._save_auto_resume()
+
     def add_logger(self, logger: BaseLogger):
         self.logger = logger
 
@@ -503,11 +568,16 @@ class Tracker:
             self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
             print(f'Error saving checkpoint: {e}')
 
+    @property
+    def can_recall(self):
+        # Defines whether a recall can be performed.
+        return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
+
     def recall(self):
-        if self.loader is not None:
+        if self.can_recall:
             return self.loader.recall()
         else:
-            raise ValueError('No loader specified')
+            raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
\ No newline at end of file
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index a016981..1bb7bfa 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -47,6 +47,8 @@ class TrainSplitConfig(BaseModel):
 
 class TrackerLogConfig(BaseModel):
     log_type: str = 'console'
+    resume: bool = False  # For logs that are saved to unique locations, resume a previous run
+    auto_resume: bool = False  # If the process crashes and restarts, resume from the run that crashed
     verbose: bool = False
 
     class Config:
@@ -59,6 +61,7 @@ class TrackerLogConfig(BaseModel):
 
 class TrackerLoadConfig(BaseModel):
     load_from: Optional[str] = None
+    only_auto_resume: bool = False  # Only attempt to load if the logger is auto-resuming
 
     class Config:
         extra = "allow"
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 14d9933..146057a 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -509,7 +509,6 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))
 
         decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
-        schedulers = list(self.accelerator.prepare(*schedulers))
 
         self.decoder = decoder
 
diff --git a/train_decoder.py b/train_decoder.py
index e99d0e7..6ab9050 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -289,9 +289,9 @@ def train(
     sample = 0
     samples_seen = 0
     val_sample = 0
-    step = lambda: int(trainer.step.item())
+    step = lambda: int(trainer.num_steps_taken(unet_number=1))
 
-    if tracker.loader is not None:
+    if tracker.can_recall:
         start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
         if next_task == 'train':
            sample = recalled_sample
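For reference, the resume options documented in configs/README.md above would be combined in the `tracker` section of the training config roughly as follows. This fragment is a sketch only: the `tracker`/`log`/`load` nesting is inferred from `TrackerLogConfig` and `TrackerLoadConfig`, the wandb entity and project values are placeholders, and loader-specific keys (such as the checkpoint path used by the `local` loader) are omitted.

    "tracker": {
        "log": {
            "log_type": "wandb",
            "wandb_entity": "your_entity",
            "wandb_project": "your_project",
            "auto_resume": true
        },
        "load": {
            "load_from": "local",
            "only_auto_resume": true
        }
    }

With a config like this, a restarted run that points at the same tracker data path would find `auto_resume.json`, re-attach to the original wandb run using the entity, project, and run id that `get_resume_data` saved, and reload the checkpoint through the loader, since `can_recall` is satisfied when `only_auto_resume` coincides with an auto-resume.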