fix wandb logging in tracker, and do some cleanup

This commit is contained in:
Phil Wang
2022-05-20 17:27:43 -07:00
parent e0524a6aff
commit 721f9687c1
3 changed files with 34 additions and 28 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,6 @@
# default experiment tracker data
.tracker-data/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@@ -1,20 +1,33 @@
import os
from itertools import zip_longest from pathlib import Path
from enum import Enum
import importlib
from itertools import zip_longest
import torch
from torch import nn
# constants
DEFAULT_DATA_PATH = './.tracker-data'
# helper functions
def exists(val):
    return val is not None
def load_wandb_state_dict(run_path, file_path, **kwargs): def import_or_print_error(pkg_name, err_str = None):
try:
import wandb return importlib.import_module(pkg_name)
except ImportError as e: except ModuleNotFoundError as e:
print('`pip install wandb` to use the wandb recall function') if exists(err_str):
raise e print(err_str)
exit()
# load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return torch.load(file_reference.name)
@@ -24,11 +37,10 @@ def load_local_state_dict(file_path, **kwargs):
# base class
class BaseTracker(nn.Module):
def __init__(self, data_path): def __init__(self, data_path = DEFAULT_DATA_PATH):
super().__init__()
assert data_path is not None, "Tracker must have a data_path to save local content" self.data_path = Path(data_path)
self.data_path = os.path.abspath(data_path) self.data_path.mkdir(parents = True, exist_ok = True)
os.makedirs(self.data_path, exist_ok=True)
def init(self, config, **kwargs):
    raise NotImplementedError
@@ -66,28 +78,19 @@ class ConsoleTracker(BaseTracker):
def log(self, log, **kwargs):
    print(log)
def log_images(self, images, **kwargs): def log_images(self, images, **kwargs): # noop for logging images
"""
Currently, do nothing with console logged images
"""
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, os.path.join(self.data_path, relative_path)) torch.save(state_dict, str(self.data_path / relative_path))
# basic wandb class
class WandbTracker(BaseTracker):
def __init__(self, data_path): def __init__(self, *args, **kwargs):
super().__init__(data_path) super().__init__(*args, **kwargs)
try: self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
import wandb
except ImportError as e:
print('`pip install wandb` to use the wandb experiment tracker')
raise e
os.environ["WANDB_SILENT"] = "true"
self.wandb = wandb
def init(self, **config):
    self.wandb.init(**config)
@@ -102,12 +105,12 @@ class WandbTracker(BaseTracker):
Takes a tensor of images and a list of captions and logs them to wandb.
"""
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
self.log({ image_section: wandb_images }, **kwargs) self.wandb.log({ image_section: wandb_images }, **kwargs)
def save_state_dict(self, state_dict, relative_path, **kwargs):
    """
    Saves a state_dict to disk and uploads it
    """
full_path = os.path.join(self.data_path, relative_path) full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path=self.data_path) # Upload and keep relative to data_path self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.3.3', version = '0.3.4',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',