mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
experiment tracker agnostic
This commit is contained in:
49
dalle2_pytorch/trackers.py
Normal file
49
dalle2_pytorch/trackers.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# base class
|
||||
|
||||
class BaseTracker(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
# basic stdout class
|
||||
|
||||
class ConsoleTracker(BaseTracker):
|
||||
def init(self, **config):
|
||||
print(config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
print(log)
|
||||
|
||||
# basic wandb class
|
||||
|
||||
class WandbTracker(BaseTracker):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
import wandb
|
||||
except ImportError as e:
|
||||
print('`pip install wandb` to use the wandb experiment tracker')
|
||||
raise e
|
||||
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
self.wandb = wandb
|
||||
|
||||
def init(self, **config):
|
||||
self.wandb.init(**config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
self.wandb.log(log, **kwargs)
|
||||
Reference in New Issue
Block a user