diff --git a/README.md b/README.md
index d3ec4cb..06c5755 100644
--- a/README.md
+++ b/README.md
@@ -1007,6 +1007,7 @@ Once built, images will be saved to the same directory the command is invoked
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
+- [x] use an experimental tracker agnostic setup, as done here
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
@@ -1014,7 +1015,6 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
-- [ ] use an experimental tracker agnostic setup, as done here
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py
new file mode 100644
index 0000000..fe07a9f
--- /dev/null
+++ b/dalle2_pytorch/trackers.py
@@ -0,0 +1,49 @@
+import os
+import torch
+from torch import nn
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+# base class
+
+class BaseTracker(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def init(self, config, **kwargs):
+ raise NotImplementedError
+
+ def log(self, log, **kwargs):
+ raise NotImplementedError
+
+# basic stdout class
+
+class ConsoleTracker(BaseTracker):
+ def init(self, **config):
+ print(config)
+
+ def log(self, log, **kwargs):
+ print(log)
+
+# basic wandb class
+
+class WandbTracker(BaseTracker):
+ def __init__(self):
+ super().__init__()
+ try:
+ import wandb
+ except ImportError as e:
+ print('`pip install wandb` to use the wandb experiment tracker')
+ raise e
+
+ os.environ["WANDB_SILENT"] = "true"
+ self.wandb = wandb
+
+ def init(self, **config):
+ self.wandb.init(**config)
+
+ def log(self, log, **kwargs):
+ self.wandb.log(log, **kwargs)
diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index 37c710a..768e72d 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -228,6 +228,8 @@ class DiffusionPriorTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
+ self.register_buffer('step', torch.tensor([0.]))
+
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
@@ -240,6 +242,8 @@ class DiffusionPriorTrainer(nn.Module):
if self.use_ema:
self.ema_diffusion_prior.update()
+ self.step += 1
+
@torch.inference_mode()
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@@ -328,6 +332,8 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
+ self.register_buffer('step', torch.tensor([0.]))
+
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
@@ -358,6 +364,8 @@ class DecoderTrainer(nn.Module):
ema_unet = self.ema_unets[index]
ema_unet.update()
+ self.step += 1
+
@torch.no_grad()
def sample(self, *args, **kwargs):
if self.use_ema:
diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index 016b4e1..f73e0ac 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -1,24 +1,26 @@
import os
import math
+import time
import argparse
import numpy as np
import torch
from torch import nn
-from embedding_reader import EmbeddingReader
+from torch.cuda.amp import autocast, GradScaler
+
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.optimizer import get_optimizer
-from torch.cuda.amp import autocast,GradScaler
+from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
+
+from embedding_reader import EmbeddingReader
-import time
from tqdm import tqdm
-import wandb
-os.environ["WANDB_SILENT"] = "true"
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
+tracker = WandbTracker()
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
model.eval()
@@ -40,7 +42,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
total_samples += batches
avg_loss = (total_loss / total_samples)
- wandb.log({f'{phase} {loss_type}': avg_loss})
+ tracker.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
diffusion_prior.eval()
@@ -87,7 +89,7 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
- wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
+ tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
@@ -201,7 +203,7 @@ def train(image_embed_dim,
image_embed_dim)
# Log to wandb
- wandb.log({"Training loss": loss.item(),
+ tracker.log({"Training loss": loss.item(),
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
@@ -306,7 +308,7 @@ def main():
if(DPRIOR_PATH is not None):
RESUME = True
else:
- wandb.init(
+ tracker.init(
entity=args.wandb_entity,
project=args.wandb_project,
config=config)
@@ -351,4 +353,4 @@ def main():
args.amp)
if __name__ == "__main__":
- main()
+ main()