From e0524a6affdfc5a05e5e0b9514e06fcfbc0ee788 Mon Sep 17 00:00:00 2001
From: Aidan Dempster
Date: Fri, 20 May 2022 19:39:23 -0400
Subject: [PATCH] Implemented the wandb tracker (#106)

Added a data_path parameter to all trackers for storing any local
information they need.
---
Review notes (text between the "---" line and the diff is ignored by `git am`):
 * WandbTracker.log_images: the default for `captions` is now an immutable
   `()` instead of a shared mutable `[]`; zip_longest treats both the same.
 * trackers.py now ends with a trailing newline.
 * The post-image blob id on the `index` line was not regenerated after
   these amendments.

 dalle2_pytorch/trackers.py | 72 +++++++++++++++++++++++++++++++++++---
 1 file changed, 68 insertions(+), 4 deletions(-)

diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py
index fe07a9f..6225256 100644
--- a/dalle2_pytorch/trackers.py
+++ b/dalle2_pytorch/trackers.py
@@ -1,4 +1,6 @@
 import os
+from itertools import zip_longest
+from enum import Enum
 
 import torch
 from torch import nn
@@ -7,11 +9,26 @@ from torch import nn
 def exists(val):
     return val is not None
 
+def load_wandb_state_dict(run_path, file_path, **kwargs):
+    try:
+        import wandb
+    except ImportError as e:
+        print('`pip install wandb` to use the wandb recall function')
+        raise e
+    file_reference = wandb.restore(file_path, run_path=run_path)
+    return torch.load(file_reference.name)
+
+def load_local_state_dict(file_path, **kwargs):
+    return torch.load(file_path)
+
 # base class
 
 class BaseTracker(nn.Module):
-    def __init__(self):
+    def __init__(self, data_path):
         super().__init__()
+        assert data_path is not None, "Tracker must have a data_path to save local content"
+        self.data_path = os.path.abspath(data_path)
+        os.makedirs(self.data_path, exist_ok=True)
 
     def init(self, config, **kwargs):
         raise NotImplementedError
@@ -19,6 +36,27 @@ class BaseTracker(nn.Module):
     def log(self, log, **kwargs):
         raise NotImplementedError
 
+    def log_images(self, images, **kwargs):
+        raise NotImplementedError
+
+    def save_state_dict(self, state_dict, relative_path, **kwargs):
+        raise NotImplementedError
+
+    def recall_state_dict(self, recall_source, *args, **kwargs):
+        """
+        Loads a state dict from any source.
+        Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
+        this should not be linked to any individual tracker.
+        """
+        # TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
+        if recall_source == 'wandb':
+            return load_wandb_state_dict(*args, **kwargs)
+        elif recall_source == 'local':
+            return load_local_state_dict(*args, **kwargs)
+        else:
+            raise ValueError('`recall_source` must be one of `wandb` or `local`')
+
+
 # basic stdout class
 
 class ConsoleTracker(BaseTracker):
@@ -28,11 +66,20 @@ class ConsoleTracker(BaseTracker):
     def log(self, log, **kwargs):
         print(log)
 
+    def log_images(self, images, **kwargs):
+        """
+        Currently, do nothing with console logged images
+        """
+        pass
+
+    def save_state_dict(self, state_dict, relative_path, **kwargs):
+        torch.save(state_dict, os.path.join(self.data_path, relative_path))
+
 # basic wandb class
 
 class WandbTracker(BaseTracker):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, data_path):
+        super().__init__(data_path)
         try:
             import wandb
         except ImportError as e:
@@ -45,5 +92,22 @@ class WandbTracker(BaseTracker):
     def init(self, **config):
         self.wandb.init(**config)
 
-    def log(self, log, **kwargs):
+    def log(self, log, verbose=False, **kwargs):
+        if verbose:
+            print(log)
         self.wandb.log(log, **kwargs)
+
+    def log_images(self, images, captions=(), image_section="images", **kwargs):
+        """
+        Takes a tensor of images and a list of captions and logs them to wandb.
+        """
+        wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
+        self.log({ image_section: wandb_images }, **kwargs)
+
+    def save_state_dict(self, state_dict, relative_path, **kwargs):
+        """
+        Saves a state_dict to disk and uploads it
+        """
+        full_path = os.path.join(self.data_path, relative_path)
+        torch.save(state_dict, full_path)
+        self.wandb.save(full_path, base_path=self.data_path) # Upload and keep relative to data_path