mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix wandb logging in tracker, and do some cleanup
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,6 @@
|
|||||||
|
# default experiment tracker data
|
||||||
|
.tracker-data/
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|||||||
@@ -1,20 +1,33 @@
|
|||||||
import os
|
import os
|
||||||
from itertools import zip_longest
|
from pathlib import Path
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import importlib
|
||||||
|
from itertools import zip_longest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
|
DEFAULT_DATA_PATH = './.tracker-data'
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
def import_or_print_error(pkg_name, err_str = None):
|
||||||
try:
|
try:
|
||||||
import wandb
|
return importlib.import_module(pkg_name)
|
||||||
except ImportError as e:
|
except ModuleNotFoundError as e:
|
||||||
print('`pip install wandb` to use the wandb recall function')
|
if exists(err_str):
|
||||||
raise e
|
print(err_str)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# load state dict functions
|
||||||
|
|
||||||
|
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
||||||
|
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||||
file_reference = wandb.restore(file_path, run_path=run_path)
|
file_reference = wandb.restore(file_path, run_path=run_path)
|
||||||
return torch.load(file_reference.name)
|
return torch.load(file_reference.name)
|
||||||
|
|
||||||
@@ -24,11 +37,10 @@ def load_local_state_dict(file_path, **kwargs):
|
|||||||
# base class
|
# base class
|
||||||
|
|
||||||
class BaseTracker(nn.Module):
|
class BaseTracker(nn.Module):
|
||||||
def __init__(self, data_path):
|
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert data_path is not None, "Tracker must have a data_path to save local content"
|
self.data_path = Path(data_path)
|
||||||
self.data_path = os.path.abspath(data_path)
|
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||||
os.makedirs(self.data_path, exist_ok=True)
|
|
||||||
|
|
||||||
def init(self, config, **kwargs):
|
def init(self, config, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -66,28 +78,19 @@ class ConsoleTracker(BaseTracker):
|
|||||||
def log(self, log, **kwargs):
|
def log(self, log, **kwargs):
|
||||||
print(log)
|
print(log)
|
||||||
|
|
||||||
def log_images(self, images, **kwargs):
|
def log_images(self, images, **kwargs): # noop for logging images
|
||||||
"""
|
|
||||||
Currently, do nothing with console logged images
|
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||||
torch.save(state_dict, os.path.join(self.data_path, relative_path))
|
torch.save(state_dict, str(self.data_path / relative_path))
|
||||||
|
|
||||||
# basic wandb class
|
# basic wandb class
|
||||||
|
|
||||||
class WandbTracker(BaseTracker):
|
class WandbTracker(BaseTracker):
|
||||||
def __init__(self, data_path):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(data_path)
|
super().__init__(*args, **kwargs)
|
||||||
try:
|
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
|
||||||
import wandb
|
|
||||||
except ImportError as e:
|
|
||||||
print('`pip install wandb` to use the wandb experiment tracker')
|
|
||||||
raise e
|
|
||||||
|
|
||||||
os.environ["WANDB_SILENT"] = "true"
|
os.environ["WANDB_SILENT"] = "true"
|
||||||
self.wandb = wandb
|
|
||||||
|
|
||||||
def init(self, **config):
|
def init(self, **config):
|
||||||
self.wandb.init(**config)
|
self.wandb.init(**config)
|
||||||
@@ -102,12 +105,12 @@ class WandbTracker(BaseTracker):
|
|||||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||||
"""
|
"""
|
||||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||||
self.log({ image_section: wandb_images }, **kwargs)
|
self.wandb.log({ image_section: wandb_images }, **kwargs)
|
||||||
|
|
||||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||||
"""
|
"""
|
||||||
Saves a state_dict to disk and uploads it
|
Saves a state_dict to disk and uploads it
|
||||||
"""
|
"""
|
||||||
full_path = os.path.join(self.data_path, relative_path)
|
full_path = str(self.data_path / relative_path)
|
||||||
torch.save(state_dict, full_path)
|
torch.save(state_dict, full_path)
|
||||||
self.wandb.save(full_path, base_path=self.data_path) # Upload and keep relative to data_path
|
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
||||||
|
|||||||
Reference in New Issue
Block a user