From 721f9687c1be45e637ca709e1b5870324760f94b Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 20 May 2022 17:27:43 -0700 Subject: [PATCH] fix wandb logging in tracker, and do some cleanup --- .gitignore | 3 ++ dalle2_pytorch/trackers.py | 57 ++++++++++++++++++++------------------ setup.py | 2 +- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index b6e4761..55301b1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# default experiment tracker data +.tracker-data/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py index 6225256..a5dbe84 100644 --- a/dalle2_pytorch/trackers.py +++ b/dalle2_pytorch/trackers.py @@ -1,20 +1,33 @@ import os -from itertools import zip_longest +from pathlib import Path from enum import Enum +import importlib +from itertools import zip_longest + import torch from torch import nn +# constants + +DEFAULT_DATA_PATH = './.tracker-data' + # helper functions def exists(val): return val is not None -def load_wandb_state_dict(run_path, file_path, **kwargs): +def import_or_print_error(pkg_name, err_str = None): try: - import wandb - except ImportError as e: - print('`pip install wandb` to use the wandb recall function') - raise e + return importlib.import_module(pkg_name) + except ModuleNotFoundError as e: + if exists(err_str): + print(err_str) + exit() + +# load state dict functions + +def load_wandb_state_dict(run_path, file_path, **kwargs): + wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function') file_reference = wandb.restore(file_path, run_path=run_path) return torch.load(file_reference.name) @@ -24,11 +37,10 @@ def load_local_state_dict(file_path, **kwargs): # base class class BaseTracker(nn.Module): - def __init__(self, data_path): + def __init__(self, data_path = DEFAULT_DATA_PATH): super().__init__() - assert data_path is not None, "Tracker must have a data_path to save local content" - self.data_path = os.path.abspath(data_path) - os.makedirs(self.data_path, exist_ok=True) + self.data_path = Path(data_path) + self.data_path.mkdir(parents = True, exist_ok = True) def init(self, config, **kwargs): raise NotImplementedError @@ -66,28 +78,19 @@ class ConsoleTracker(BaseTracker): def log(self, log, **kwargs): print(log) - def log_images(self, images, **kwargs): - """ - Currently, do nothing with console logged images - """ + def log_images(self, images, **kwargs): # noop for logging images pass def save_state_dict(self, state_dict, relative_path, **kwargs): - torch.save(state_dict, os.path.join(self.data_path, relative_path)) + torch.save(state_dict, str(self.data_path / relative_path)) # basic wandb class class WandbTracker(BaseTracker): - def __init__(self, data_path): - super().__init__(data_path) - try: - import wandb - except ImportError as e: - print('`pip install wandb` to use the wandb experiment tracker') - raise e - + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker') os.environ["WANDB_SILENT"] = "true" - self.wandb = wandb def init(self, **config): self.wandb.init(**config) @@ -102,12 +105,12 @@ class WandbTracker(BaseTracker): Takes a tensor of images and a list of captions and logs them to wandb. """ wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)] - self.log({ image_section: wandb_images }, **kwargs) + self.wandb.log({ image_section: wandb_images }, **kwargs) def save_state_dict(self, state_dict, relative_path, **kwargs): """ Saves a state_dict to disk and uploads it """ - full_path = os.path.join(self.data_path, relative_path) + full_path = str(self.data_path / relative_path) torch.save(state_dict, full_path) - self.wandb.save(full_path, base_path=self.data_path) # Upload and keep relative to data_path \ No newline at end of file + self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path diff --git a/setup.py b/setup.py index e3731a8..638b78a 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.3.3', + version = '0.3.4', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',