mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Added a base_path parameter to all trackers for storing any local information they need to
113 lines
3.6 KiB
Python
113 lines
3.6 KiB
Python
import os
|
|
from itertools import zip_longest
|
|
from enum import Enum
|
|
import torch
|
|
from torch import nn
|
|
|
|
# helper functions
|
|
|
|
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True
|
|
|
|
def load_wandb_state_dict(run_path, file_path, **kwargs):
    """
    Restore a file from a wandb run and load it with torch.

    `file_path` and `run_path` are handed to `wandb.restore`; the `name` of the
    returned file reference is then passed to `torch.load`, whose result is returned.
    Extra keyword arguments are accepted but not used.

    Raises ImportError (after printing an install hint) when wandb is not installed.
    """
    # wandb is imported lazily so it is only required when this source is used
    try:
        import wandb
    except ImportError as import_error:
        print('`pip install wandb` to use the wandb recall function')
        raise import_error
    restored = wandb.restore(file_path, run_path=run_path)
    local_name = restored.name
    return torch.load(local_name)
|
|
|
|
def load_local_state_dict(file_path, **kwargs):
    """
    Load whatever `torch.load` reads from `file_path` and return it.

    Extra keyword arguments are accepted but not used.
    """
    state_dict = torch.load(file_path)
    return state_dict
|
|
|
|
# base class
|
|
|
|
class BaseTracker(nn.Module):
    """
    Base class for trackers.

    Stores an absolute `data_path` directory (created on construction) and declares
    the `init`, `log`, `log_images` and `save_state_dict` hooks, all of which raise
    NotImplementedError here. `recall_state_dict` is implemented on the base class.
    """

    def __init__(self, data_path):
        super().__init__()
        assert data_path is not None, "Tracker must have a data_path to save local content"
        absolute_path = os.path.abspath(data_path)
        self.data_path = absolute_path
        # Make sure the directory exists; an already existing one is not an error
        os.makedirs(absolute_path, exist_ok=True)

    def init(self, config, **kwargs):
        """Not implemented on the base class."""
        raise NotImplementedError

    def log(self, log, **kwargs):
        """Not implemented on the base class."""
        raise NotImplementedError

    def log_images(self, images, **kwargs):
        """Not implemented on the base class."""
        raise NotImplementedError

    def save_state_dict(self, state_dict, relative_path, **kwargs):
        """Not implemented on the base class."""
        raise NotImplementedError

    def recall_state_dict(self, recall_source, *args, **kwargs):
        """
        Loads a state dict from any source.

        A user may want to recall a model from somewhere other than the tracker in use
        (i.e. tracking with wandb but recalling from disk), so the source is chosen by
        `recall_source` rather than by the tracker type. Remaining positional and keyword
        arguments are forwarded to the loader for that source.

        Raises ValueError when `recall_source` is neither 'wandb' nor 'local'.
        """
        # TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
        if recall_source == 'wandb':
            return load_wandb_state_dict(*args, **kwargs)
        if recall_source == 'local':
            return load_local_state_dict(*args, **kwargs)
        raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
|
|
|
|
|
# basic stdout class
|
|
|
|
class ConsoleTracker(BaseTracker):
    """
    Tracker that prints to stdout and saves state dicts under `self.data_path`.
    """

    def init(self, **config):
        """Print the configuration keyword arguments (as a dict)."""
        print(config)

    def log(self, log, **kwargs):
        """Print `log`; extra keyword arguments are ignored."""
        print(log)

    def log_images(self, images, **kwargs):
        """
        Currently, do nothing with console logged images
        """
        pass

    def save_state_dict(self, state_dict, relative_path, **kwargs):
        """
        Save `state_dict` with `torch.save` at `relative_path` joined onto `self.data_path`.

        Extra keyword arguments are ignored.
        """
        full_path = os.path.join(self.data_path, relative_path)
        # relative_path may name a sub-directory (e.g. "checkpoints/latest.pth") that
        # does not exist yet; create it so torch.save does not fail on a missing parent.
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        torch.save(state_dict, full_path)
|
|
|
|
# basic wandb class
|
|
|
|
class WandbTracker(BaseTracker):
    """
    Tracker that logs to wandb and uploads saved state dicts.

    The wandb module is imported in the constructor and kept on `self.wandb`, so the
    package is only required when this tracker is actually used.
    """

    def __init__(self, data_path):
        super().__init__(data_path)
        try:
            import wandb
        except ImportError:
            print('`pip install wandb` to use the wandb experiment tracker')
            raise

        # NOTE(review): presumably silences wandb's console output - confirm against wandb docs
        os.environ["WANDB_SILENT"] = "true"
        self.wandb = wandb

    def init(self, **config):
        """Forward all keyword arguments to `wandb.init`."""
        self.wandb.init(**config)

    def log(self, log, verbose=False, **kwargs):
        """
        Send `log` to `wandb.log`, forwarding extra keyword arguments.

        When `verbose` is truthy, `log` is also printed to stdout first.
        """
        if verbose:
            print(log)
        self.wandb.log(log, **kwargs)

    def log_images(self, images, captions=None, image_section="images", **kwargs):
        """
        Takes images and a list of captions and logs them to wandb.

        Each image is wrapped in `wandb.Image` and the resulting list is logged under the
        `image_section` key. Images and captions are paired with `zip_longest`, so images
        without a matching caption get `caption=None`. `captions` defaults to no captions.
        Extra keyword arguments are forwarded to `self.log`.
        """
        # None instead of a mutable [] default; behaves the same for callers that omit it
        captions = [] if captions is None else captions
        wandb_images = [
            self.wandb.Image(image, caption=caption)
            for image, caption in zip_longest(images, captions)
        ]
        self.log({image_section: wandb_images}, **kwargs)

    def save_state_dict(self, state_dict, relative_path, **kwargs):
        """
        Saves a state_dict to disk and uploads it

        The file is written to `relative_path` joined onto `self.data_path` and then
        passed to `wandb.save`. Extra keyword arguments are ignored.
        """
        full_path = os.path.join(self.data_path, relative_path)
        # relative_path may name a sub-directory (e.g. "checkpoints/latest.pth") that
        # does not exist yet; create it so torch.save does not fail on a missing parent.
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        torch.save(state_dict, full_path)
        self.wandb.save(full_path, base_path=self.data_path)  # Upload and keep relative to data_path