experiment tracker agnostic

Phil Wang
2022-05-15 09:56:40 -07:00
parent 4ec6d0ba81
commit 89de5af63e
4 changed files with 70 additions and 11 deletions

View File

@@ -1007,6 +1007,7 @@ Once built, images will be saved to the same directory the command is invoked
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
+- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
@@ -1014,7 +1015,6 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
-- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes

View File

@@ -0,0 +1,49 @@
import os
import torch
from torch import nn

# helper functions

def exists(val):
    return val is not None

# base class

class BaseTracker(nn.Module):
    def __init__(self):
        super().__init__()

    def init(self, config, **kwargs):
        raise NotImplementedError

    def log(self, log, **kwargs):
        raise NotImplementedError

# basic stdout class

class ConsoleTracker(BaseTracker):
    def init(self, **config):
        print(config)

    def log(self, log, **kwargs):
        print(log)

# basic wandb class

class WandbTracker(BaseTracker):
    def __init__(self):
        super().__init__()
        try:
            import wandb
        except ImportError as e:
            print('`pip install wandb` to use the wandb experiment tracker')
            raise e

        os.environ["WANDB_SILENT"] = "true"
        self.wandb = wandb

    def init(self, **config):
        self.wandb.init(**config)

    def log(self, log, **kwargs):
        self.wandb.log(log, **kwargs)
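
Because the training code relies only on the `init` / `log` interface of `BaseTracker`, additional backends can be plugged in by subclassing it. As a hypothetical illustration (not part of this commit), a tracker that appends each log entry to a JSON-lines file could look like the sketch below; the class name and default file path are assumptions for the example.

import json

class JsonFileTracker(BaseTracker):
    # hypothetical tracker: records the config once, then one JSON object per log call
    def __init__(self, path = './experiment_log.jsonl'):
        super().__init__()
        self.path = path

    def init(self, **config):
        with open(self.path, 'w') as f:
            f.write(json.dumps(dict(config = config)) + '\n')

    def log(self, log, **kwargs):
        with open(self.path, 'a') as f:
            f.write(json.dumps(log) + '\n')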

View File

@@ -228,6 +228,8 @@ class DiffusionPriorTrainer(nn.Module):
        self.max_grad_norm = max_grad_norm

+        self.register_buffer('step', torch.tensor([0.]))

    def update(self):
        if exists(self.max_grad_norm):
            self.scaler.unscale_(self.optimizer)
@@ -240,6 +242,8 @@ class DiffusionPriorTrainer(nn.Module):
        if self.use_ema:
            self.ema_diffusion_prior.update()

+        self.step += 1

    @torch.inference_mode()
    def p_sample_loop(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@@ -328,6 +332,8 @@ class DecoderTrainer(nn.Module):
        self.max_grad_norm = max_grad_norm

+        self.register_buffer('step', torch.tensor([0.]))

    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
@@ -358,6 +364,8 @@ class DecoderTrainer(nn.Module):
            ema_unet = self.ema_unets[index]
            ema_unet.update()

+        self.step += 1

    @torch.no_grad()
    def sample(self, *args, **kwargs):
        if self.use_ema:
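
The new `step` counter is registered as a buffer rather than stored as a plain attribute, so it is included in the trainer's state dict and survives a save / load round trip. A minimal sketch of the mechanism (not the repository's actual trainer class):

import torch
from torch import nn

class StepCounterSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # a buffer is part of state_dict(), unlike a plain Python attribute
        self.register_buffer('step', torch.tensor([0.]))

    def update(self):
        # optimizer / EMA updates would happen here in the real trainers
        self.step += 1

sketch = StepCounterSketch()
sketch.update()
assert 'step' in sketch.state_dict()             # counter is saved with the model
assert sketch.state_dict()['step'].item() == 1.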

View File

@@ -1,24 +1,26 @@
import os
import math
+import time
import argparse
import numpy as np
import torch
from torch import nn
-from embedding_reader import EmbeddingReader
+from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.optimizer import get_optimizer
-from torch.cuda.amp import autocast,GradScaler
+from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
+from embedding_reader import EmbeddingReader
-import time
from tqdm import tqdm
-import wandb
-os.environ["WANDB_SILENT"] = "true"

NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training

+tracker = WandbTracker()
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"): def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
model.eval() model.eval()
@@ -40,7 +42,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
            total_samples += batches
        avg_loss = (total_loss / total_samples)
-        wandb.log({f'{phase} {loss_type}': avg_loss})
+        tracker.log({f'{phase} {loss_type}': avg_loss})

def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
    diffusion_prior.eval()
@@ -87,7 +89,7 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
        text_embed, predicted_unrelated_embeddings).cpu().numpy()
    predicted_img_similarity = cos(
        test_image_embeddings, predicted_image_embeddings).cpu().numpy()
-    wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
+    tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
        "CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
        "CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
        "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
@@ -201,7 +203,7 @@ def train(image_embed_dim,
                    image_embed_dim)
                # Log to wandb
-                wandb.log({"Training loss": loss.item(),
+                tracker.log({"Training loss": loss.item(),
                    "Steps": step,
                    "Samples per second": samples_per_sec})
                # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
@@ -306,7 +308,7 @@ def main():
    if(DPRIOR_PATH is not None):
        RESUME = True
    else:
-        wandb.init(
+        tracker.init(
            entity=args.wandb_entity,
            project=args.wandb_project,
            config=config)
@@ -351,4 +353,4 @@ def main():
        args.amp)

if __name__ == "__main__":
    main()
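
The script above still instantiates `WandbTracker` unconditionally, so switching backends means editing that one line. A hypothetical extension (not part of this commit) could choose the tracker from a command-line flag; the `--tracker` argument below is an assumption for illustration only.

import argparse
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker

parser = argparse.ArgumentParser()
parser.add_argument('--tracker', type = str, default = 'console', choices = ('console', 'wandb'))  # hypothetical flag
args = parser.parse_args()

# both trackers expose the same init / log interface, so the rest of the script is unchanged
tracker = WandbTracker() if args.tracker == 'wandb' else ConsoleTracker()
tracker.init(project = 'diffusion-prior-example', config = vars(args))
tracker.log({'backend': args.tracker})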