Compare commits

...

9 Commits

5 changed files with 357 additions and 175 deletions

View File

@@ -732,8 +732,8 @@ clip = CLIP(
 # mock data

-text = torch.randint(0, 49408, (4, 256)).cuda()
-images = torch.randn(4, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()

 # decoder (with unet)
@@ -774,7 +774,12 @@ decoder_trainer = DecoderTrainer(
 )

 for unet_number in (1, 2):
-    loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
+    loss = decoder_trainer(
+        images,
+        text = text,
+        unet_number = unet_number, # which unet to train on
+        max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
+    )
     decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -809,8 +814,8 @@ clip = CLIP(
 # mock data

-text = torch.randint(0, 49408, (4, 256)).cuda()
-images = torch.randn(4, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()

 # prior networks (with transformer)
@@ -837,7 +842,7 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
     ema_update_every = 10,
 )

-loss = diffusion_prior_trainer(text, images)
+loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
 diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior

 # after much of the above three lines in a loop
@@ -1002,6 +1007,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [x] cross embed layers for downsampling, as an option
+- [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1009,7 +1015,6 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
-- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
 - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes

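The README change above switches the mock batch size from 4 to 32 and introduces `max_batch_size`, which enables gradient accumulation: the trainer splits each incoming batch into chunks of at most `max_batch_size` rows, runs a forward and backward pass per chunk, and weights each chunk's loss by its share of the full batch, so the accumulated gradients match a single full-batch pass. A minimal standalone sketch of that chunking arithmetic (illustrative only, not code from this diff):

```python
# illustrative sketch of the chunking behind max_batch_size (not code from this diff)
import torch
from math import ceil

batch_size, max_batch_size = 32, 4
images = torch.randn(batch_size, 3, 256, 256)

chunks = images.split(max_batch_size, dim = 0)
assert len(chunks) == ceil(batch_size / max_batch_size) # 32 / 4 == 8 forward/backward passes

# each chunk's loss is scaled by its fraction of the batch before backward,
# so summing the per-chunk gradients reproduces the full-batch gradient
fracs = [chunk.shape[0] / batch_size for chunk in chunks]
assert abs(sum(fracs) - 1.) < 1e-6
```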
View File

@@ -0,0 +1,49 @@
+import os
+
+import torch
+from torch import nn
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+# base class
+
+class BaseTracker(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def init(self, config, **kwargs):
+        raise NotImplementedError
+
+    def log(self, log, **kwargs):
+        raise NotImplementedError
+
+# basic stdout class
+
+class ConsoleTracker(BaseTracker):
+    def init(self, **config):
+        print(config)
+
+    def log(self, log, **kwargs):
+        print(log)
+
+# basic wandb class
+
+class WandbTracker(BaseTracker):
+    def __init__(self):
+        super().__init__()
+
+        try:
+            import wandb
+        except ImportError as e:
+            print('`pip install wandb` to use the wandb experiment tracker')
+            raise e
+
+        os.environ["WANDB_SILENT"] = "true"
+        self.wandb = wandb
+
+    def init(self, **config):
+        self.wandb.init(**config)
+
+    def log(self, log, **kwargs):
+        self.wandb.log(log, **kwargs)

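The new file above is the tracker-agnostic logging setup from the to-do list: subclasses implement only `init` and `log`, so a training script can swap stdout logging for wandb without touching the training loop. A short usage sketch (assuming the module lands at `dalle2_pytorch.trackers`, consistent with the import in the training script later in this diff):

```python
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker

tracker = ConsoleTracker() # or WandbTracker(), if wandb is installed

# forwarded verbatim to wandb.init by the wandb tracker; simply printed by the console tracker
tracker.init(entity = 'laion', project = 'diffusion-prior', config = {'lr': 1.1e-4})

for step in range(3):
    tracker.log({'Training loss': 0.1 * step, 'Steps': step})
```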
View File

@@ -1,6 +1,8 @@
 import time
 import copy
+from math import ceil
 from functools import partial
+from collections.abc import Iterable

 import torch
 from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
 def exists(val):
     return val is not None

+def default(val, d):
+    return val if exists(val) else d
+
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
@@ -40,6 +45,56 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs

+# gradient accumulation functions
+
+def split_iterable(it, split_size):
+    accum = []
+    for ind in range(ceil(len(it) / split_size)):
+        start_index = ind * split_size
+        accum.append(it[start_index: (start_index + split_size)])
+    return accum
+
+def split(t, split_size = None):
+    if not exists(split_size):
+        return t
+
+    if isinstance(t, torch.Tensor):
+        return t.split(split_size, dim = 0)
+
+    if isinstance(t, Iterable):
+        return split_iterable(t, split_size)
+
+    raise TypeError
+
+def find_first(cond, arr):
+    for el in arr:
+        if cond(el):
+            return el
+    return None
+
+def split_args_and_kwargs(*args, split_size = None, **kwargs):
+    all_args = (*args, *kwargs.values())
+    len_all_args = len(all_args)
+    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
+    assert exists(first_tensor)
+
+    batch_size = len(first_tensor)
+    split_size = default(split_size, batch_size)
+    num_chunks = ceil(batch_size / split_size)
+
+    dict_len = len(kwargs)
+    dict_keys = kwargs.keys()
+    split_kwargs_index = len_all_args - dict_len
+
+    # tensors and iterables are split along the batch dimension, everything else is repeated per chunk
+    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
+    chunk_sizes = tuple(map(len, split_all_args[0]))
+
+    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
+        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
+        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
+        chunk_size_frac = chunk_size / batch_size
+        yield chunk_size_frac, (chunked_args, chunked_kwargs)
+
 # print helpers

 def print_ribbon(s, symbol = '=', repeat = 40):
@@ -71,7 +126,7 @@ def load_diffusion_model(dprior_path, device):
     # Load state dict from saved model
     diffusion_prior.load_state_dict(loaded_obj['model'])

-    return diffusion_prior
+    return diffusion_prior, loaded_obj

 def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
     # Saving State Dict
@@ -182,6 +237,8 @@ class DiffusionPriorTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm

+        self.register_buffer('step', torch.tensor([0.]))
+
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
@@ -194,6 +251,8 @@ class DiffusionPriorTrainer(nn.Module):
         if self.use_ema:
             self.ema_diffusion_prior.update()

+        self.step += 1
+
     @torch.inference_mode()
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@@ -209,14 +268,20 @@ class DiffusionPriorTrainer(nn.Module):
     def forward(
         self,
         *args,
-        divisor = 1,
+        max_batch_size = None,
         **kwargs
     ):
-        with autocast(enabled = self.amp):
-            loss = self.diffusion_prior(*args, **kwargs)
-            scaled_loss = self.scaler.scale(loss / divisor)
-            scaled_loss.backward()
-        return loss.item()
+        total_loss = 0.
+
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
+                loss = loss * chunk_size_frac
+
+            total_loss += loss.item()
+            self.scaler.scale(loss).backward()
+
+        return total_loss

 # decoder trainer
@@ -275,6 +340,8 @@ class DecoderTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm

+        self.register_buffer('step', torch.tensor([0.]))
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
@@ -305,6 +372,8 @@ class DecoderTrainer(nn.Module):
         ema_unet = self.ema_unets[index]
         ema_unet.update()

+        self.step += 1
+
     @torch.no_grad()
     def sample(self, *args, **kwargs):
         if self.use_ema:
@@ -324,14 +393,19 @@ class DecoderTrainer(nn.Module):
     def forward(
         self,
-        x,
-        *,
+        *args,
         unet_number,
-        divisor = 1,
+        max_batch_size = None,
         **kwargs
     ):
-        with autocast(enabled = self.amp):
-            loss = self.decoder(x, unet_number = unet_number, **kwargs)
-            scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
-            scaled_loss.backward()
-        return loss.item()
+        total_loss = 0.
+
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+                loss = loss * chunk_size_frac
+
+            total_loss += loss.item()
+            self.scale(loss, unet_number = unet_number).backward()
+
+        return total_loss

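For reference, the `split_args_and_kwargs` generator added above yields `(chunk_fraction, (args, kwargs))` tuples, splitting every tensor or iterable argument along the batch dimension and repeating everything else per chunk; both trainers consume it for gradient accumulation. A quick standalone check (a sketch, not a test from the repository; the import path is inferred from this diff):

```python
import torch
from dalle2_pytorch.train import split_args_and_kwargs # import path inferred from this diff

images = torch.randn(10, 3, 8, 8)
text = torch.randint(0, 100, (10, 77))

# 10 rows in chunks of at most 4 -> sizes (4, 4, 2), fractions (0.4, 0.4, 0.2)
for frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(images, split_size = 4, text = text):
    print(frac, chunked_args[0].shape, chunked_kwargs['text'].shape)
```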
View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.24',
+  version = '0.2.31',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

View File

@@ -1,24 +1,42 @@
-import os
+from pathlib import Path
+import click
 import math
-import argparse
+import time
 import numpy as np
 import torch
 from torch import nn
-from embedding_reader import EmbeddingReader
-from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
-from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
-from dalle2_pytorch.optimizer import get_optimizer
-from torch.cuda.amp import autocast,GradScaler
-import time
+
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
+from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
+from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
+from embedding_reader import EmbeddingReader
 from tqdm import tqdm
-import wandb
-os.environ["WANDB_SILENT"] = "true"
+
+# constants

 NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
 REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training

+tracker = WandbTracker()
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+class Timer:
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.last_time = time.time()
+
+    def elapsed(self):
+        return time.time() - self.last_time
+
+# functions
+
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
@@ -40,7 +58,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
         total_samples += batches

     avg_loss = (total_loss / total_samples)
-    wandb.log({f'{phase} {loss_type}': avg_loss})
+    tracker.log({f'{phase} {loss_type}': avg_loss})

 def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
     diffusion_prior.eval()
@@ -87,7 +105,7 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
         text_embed, predicted_unrelated_embeddings).cpu().numpy()
     predicted_img_similarity = cos(
         test_image_embeddings, predicted_image_embeddings).cpu().numpy()
-    wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
+    tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
             "CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
             "CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
             "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
@@ -124,48 +142,67 @@ def train(image_embed_dim,
           dropout=0.05,
           amp=False):

-    # DiffusionPriorNetwork
+    # diffusion prior network
     prior_network = DiffusionPriorNetwork(
         dim = image_embed_dim,
         depth = dpn_depth,
         dim_head = dpn_dim_head,
         heads = dpn_heads,
         attn_dropout = dropout,
         ff_dropout = dropout,
-        normformer = dp_normformer).to(device)
+        normformer = dp_normformer
+    )

-    # DiffusionPrior with text embeddings and image embeddings pre-computed
+    # diffusion prior with text embeddings and image embeddings pre-computed
     diffusion_prior = DiffusionPrior(
         net = prior_network,
         clip = clip,
         image_embed_dim = image_embed_dim,
         timesteps = dp_timesteps,
         cond_drop_prob = dp_cond_drop_prob,
         loss_type = dp_loss_type,
-        condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
+        condition_on_text_encodings = dp_condition_on_text_encodings
+    )

     # Load pre-trained model from DPRIOR_PATH
     if RESUME:
-        diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
-        wandb.init( entity=wandb_entity, project=wandb_project, config=config)
+        diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
+        tracker.init(entity = wandb_entity, project = wandb_project, config = config)
+
+    # diffusion prior trainer
+    trainer = DiffusionPriorTrainer(
+        diffusion_prior = diffusion_prior,
+        lr = learning_rate,
+        wd = weight_decay,
+        max_grad_norm = max_grad_norm,
+        amp = amp,
+    ).to(device)
+
+    # load optimizer and scaler
+    if RESUME:
+        trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
+        trainer.scaler.load_state_dict(loaded_obj['scaler'])

     # Create save_path if it doesn't exist
-    if not os.path.exists(save_path):
-        os.makedirs(save_path)
+    Path(save_path).mkdir(exist_ok = True, parents = True)

     # Get image and text embeddings from the servers
     print_ribbon("Downloading embeddings - image and text")
     image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
     text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
     num_data_points = text_reader.count

     ### Training code ###
-    scaler = GradScaler(enabled=amp)
-    optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
-    epochs = num_epochs
-    step = 0
-    t = time.time()
+    timer = Timer()
+    epochs = num_epochs

     train_set_size = int(train_percent*num_data_points)
     val_set_size = int(val_percent*num_data_points)
@@ -176,32 +213,31 @@ def train(image_embed_dim,
     for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
                                    text_reader(batch_size=batch_size, start=0, end=train_set_size)):

-        diffusion_prior.train()
+        trainer.train()

         emb_images_tensor = torch.tensor(emb_images[0]).to(device)
         emb_text_tensor = torch.tensor(emb_text[0]).to(device)

-        with autocast(enabled=amp):
-            loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-            scaler.scale(loss).backward()
+        loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)

         # Samples per second
-        step+=1
-        samples_per_sec = batch_size*step/(time.time()-t)
+        step = int(trainer.step.item()) # the step counter now lives on the trainer
+        samples_per_sec = batch_size * step / timer.elapsed()

         # Save checkpoint every save_interval minutes
-        if(int(time.time()-t) >= 60*save_interval):
-            t = time.time()
+        if(int(timer.elapsed()) >= 60 * save_interval):
+            timer.reset()

             save_diffusion_model(
                 save_path,
                 diffusion_prior,
-                optimizer,
-                scaler,
+                trainer.optimizer,
+                trainer.scaler,
                 config,
                 image_embed_dim)

         # Log to wandb
-        wandb.log({"Training loss": loss.item(),
+        tracker.log({"Training loss": loss, # the trainer forward already returns a python float
                 "Steps": step,
                 "Samples per second": samples_per_sec})
         # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
@@ -225,91 +261,109 @@ def train(image_embed_dim,
                        dp_loss_type,
                        phase="Validation")

-        scaler.unscale_(optimizer)
-        nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
+        trainer.update()

     ### Test run ###
     test_set_size = int(test_percent*train_set_size)
-    start=train_set_size+val_set_size
-    end=num_data_points
+    start = train_set_size+val_set_size
+    end = num_data_points
     eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")

-def main():
-    parser = argparse.ArgumentParser()
-    # Logging
-    parser.add_argument("--wandb-entity", type=str, default="laion")
-    parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
-    parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
-    parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
-    # URLs for embeddings
-    parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
-    parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
-    # Hyperparameters
-    parser.add_argument("--learning-rate", type=float, default=1.1e-4)
-    parser.add_argument("--weight-decay", type=float, default=6.02e-2)
-    parser.add_argument("--dropout", type=float, default=5e-2)
-    parser.add_argument("--max-grad-norm", type=float, default=0.5)
-    parser.add_argument("--batch-size", type=int, default=10**4)
-    parser.add_argument("--num-epochs", type=int, default=5)
-    # Image embed dimension
-    parser.add_argument("--image-embed-dim", type=int, default=768)
-    # Train-test split
-    parser.add_argument("--train-percent", type=float, default=0.7)
-    parser.add_argument("--val-percent", type=float, default=0.2)
-    parser.add_argument("--test-percent", type=float, default=0.1)
-    # LAION training(pre-computed embeddings)
-    # DiffusionPriorNetwork(dpn) parameters
-    parser.add_argument("--dpn-depth", type=int, default=6)
-    parser.add_argument("--dpn-dim-head", type=int, default=64)
-    parser.add_argument("--dpn-heads", type=int, default=8)
-    # DiffusionPrior(dp) parameters
-    parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
-    parser.add_argument("--dp-timesteps", type=int, default=100)
-    parser.add_argument("--dp-normformer", type=bool, default=False)
-    parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
-    parser.add_argument("--dp-loss-type", type=str, default="l2")
-    parser.add_argument("--clip", type=str, default=None)
-    parser.add_argument("--amp", type=bool, default=False)
-    # Model checkpointing interval(minutes)
-    parser.add_argument("--save-interval", type=int, default=30)
-    parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
-    # Saved model path
-    parser.add_argument("--pretrained-model-path", type=str, default=None)
-
-    args = parser.parse_args()
-
-    config = ({"learning_rate": args.learning_rate,
-        "architecture": args.wandb_arch,
-        "dataset": args.wandb_dataset,
-        "weight_decay":args.weight_decay,
-        "max_gradient_clipping_norm":args.max_grad_norm,
-        "batch_size":args.batch_size,
-        "epochs": args.num_epochs,
-        "diffusion_prior_network":{"depth":args.dpn_depth,
-            "dim_head":args.dpn_dim_head,
-            "heads":args.dpn_heads,
-            "normformer":args.dp_normformer},
-        "diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
-            "timesteps": args.dp_timesteps,
-            "cond_drop_prob":args.dp_cond_drop_prob,
-            "loss_type":args.dp_loss_type,
-            "clip":args.clip}
-    })
-
-    RESUME = False
+@click.command()
+@click.option("--wandb-entity", default="laion")
+@click.option("--wandb-project", default="diffusion-prior")
+@click.option("--wandb-dataset", default="LAION-5B")
+@click.option("--wandb-arch", default="DiffusionPrior")
+@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
+@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
+@click.option("--learning-rate", default=1.1e-4)
+@click.option("--weight-decay", default=6.02e-2)
+@click.option("--dropout", default=5e-2)
+@click.option("--max-grad-norm", default=0.5)
+@click.option("--batch-size", default=10**4)
+@click.option("--num-epochs", default=5)
+@click.option("--image-embed-dim", default=768)
+@click.option("--train-percent", default=0.7)
+@click.option("--val-percent", default=0.2)
+@click.option("--test-percent", default=0.1)
+@click.option("--dpn-depth", default=6)
+@click.option("--dpn-dim-head", default=64)
+@click.option("--dpn-heads", default=8)
+@click.option("--dp-condition-on-text-encodings", default=False)
+@click.option("--dp-timesteps", default=100)
+@click.option("--dp-normformer", default=False)
+@click.option("--dp-cond-drop-prob", default=0.1)
+@click.option("--dp-loss-type", default="l2")
+@click.option("--clip", default=None)
+@click.option("--amp", default=False)
+@click.option("--save-interval", default=30)
+@click.option("--save-path", default="./diffusion_prior_checkpoints")
+@click.option("--pretrained-model-path", default=None)
+def main(
+    wandb_entity,
+    wandb_project,
+    wandb_dataset,
+    wandb_arch,
+    image_embed_url,
+    text_embed_url,
+    learning_rate,
+    weight_decay,
+    dropout,
+    max_grad_norm,
+    batch_size,
+    num_epochs,
+    image_embed_dim,
+    train_percent,
+    val_percent,
+    test_percent,
+    dpn_depth,
+    dpn_dim_head,
+    dpn_heads,
+    dp_condition_on_text_encodings,
+    dp_timesteps,
+    dp_normformer,
+    dp_cond_drop_prob,
+    dp_loss_type,
+    clip,
+    amp,
+    save_interval,
+    save_path,
+    pretrained_model_path
+):
+    config = {
+        "learning_rate": learning_rate,
+        "architecture": wandb_arch,
+        "dataset": wandb_dataset,
+        "weight_decay": weight_decay,
+        "max_gradient_clipping_norm": max_grad_norm,
+        "batch_size": batch_size,
+        "epochs": num_epochs,
+        "diffusion_prior_network": {
+            "depth": dpn_depth,
+            "dim_head": dpn_dim_head,
+            "heads": dpn_heads,
+            "normformer": dp_normformer
+        },
+        "diffusion_prior": {
+            "condition_on_text_encodings": dp_condition_on_text_encodings,
+            "timesteps": dp_timesteps,
+            "cond_drop_prob": dp_cond_drop_prob,
+            "loss_type": dp_loss_type,
+            "clip": clip
+        }
+    }
+
     # Check if DPRIOR_PATH exists(saved model path)
-    DPRIOR_PATH = args.pretrained_model_path
-    if(DPRIOR_PATH is not None):
-        RESUME = True
-    else:
-        wandb.init(
-            entity=args.wandb_entity,
-            project=args.wandb_project,
-            config=config)
+    DPRIOR_PATH = pretrained_model_path
+    RESUME = exists(DPRIOR_PATH)
+
+    if not RESUME:
+        tracker.init(
+            entity = wandb_entity,
+            project = wandb_project,
+            config = config
+        )

     # Obtain the utilized device.
@@ -319,36 +373,36 @@ def main():
         torch.cuda.set_device(device)

     # Training loop
-    train(args.image_embed_dim,
-          args.image_embed_url,
-          args.text_embed_url,
-          args.batch_size,
-          args.train_percent,
-          args.val_percent,
-          args.test_percent,
-          args.num_epochs,
-          args.dp_loss_type,
-          args.clip,
-          args.dp_condition_on_text_encodings,
-          args.dp_timesteps,
-          args.dp_normformer,
-          args.dp_cond_drop_prob,
-          args.dpn_depth,
-          args.dpn_dim_head,
-          args.dpn_heads,
-          args.save_interval,
-          args.save_path,
+    train(image_embed_dim,
+          image_embed_url,
+          text_embed_url,
+          batch_size,
+          train_percent,
+          val_percent,
+          test_percent,
+          num_epochs,
+          dp_loss_type,
+          clip,
+          dp_condition_on_text_encodings,
+          dp_timesteps,
+          dp_normformer,
+          dp_cond_drop_prob,
+          dpn_depth,
+          dpn_dim_head,
+          dpn_heads,
+          save_interval,
+          save_path,
           device,
           RESUME,
           DPRIOR_PATH,
           config,
-          args.wandb_entity,
-          args.wandb_project,
-          args.learning_rate,
-          args.max_grad_norm,
-          args.weight_decay,
-          args.dropout,
-          args.amp)
+          wandb_entity,
+          wandb_project,
+          learning_rate,
+          max_grad_norm,
+          weight_decay,
+          dropout,
+          amp)

 if __name__ == "__main__":
     main()
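
Usage note: with `main` now a click command, the script can be driven programmatically as well as from the shell. A hypothetical resume invocation (the checkpoint path and values are illustrative; only the flag names come from the options above):

```python
# hypothetical invocation; the checkpoint path below is illustrative
from train_diffusion_prior import main

main([
    '--pretrained-model-path', './diffusion_prior_checkpoints/model.pth',
    '--batch-size', '10000',
    '--num-epochs', '5',
], standalone_mode = False) # standalone_mode=False keeps click from calling sys.exit
```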