Compare commits

...

9 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Phil Wang | 11d4e11f10 | allow for training unconditional ddpm or cascading ddpms | 2022-05-15 16:54:56 -07:00 |
| Phil Wang | 99778e12de | trainer classes now takes care of auto-casting numpy to torch tensors, and setting correct device based on model parameter devices | 2022-05-15 15:25:45 -07:00 |
| Phil Wang | 0f0011caf0 | todo | 2022-05-15 14:28:35 -07:00 |
| Phil Wang | 7b7a62044a | use eval vs training mode to determine whether to call backprop on trainer forward | 2022-05-15 14:20:59 -07:00 |
| Phil Wang | 156fe5ed9f | final cleanup for the day | 2022-05-15 12:38:41 -07:00 |
| Phil Wang | 5ec34bebe1 | cleanup readme | 2022-05-15 12:29:26 -07:00 |
| Phil Wang | 8eaacf1ac1 | remove indirection | 2022-05-15 12:05:45 -07:00 |
| Phil Wang | e66c7b0249 | incorrect naming | 2022-05-15 11:23:52 -07:00 |
| Phil Wang | f7cd4a0992 | product management | 2022-05-15 11:21:12 -07:00 |
5 changed files with 226 additions and 212 deletions

README.md

@@ -706,7 +706,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
-## Training wrapper (wip)
## Training wrapper
### Decoder Training
@@ -851,6 +851,57 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
## Bonus
### Unconditional Training
The repository also contains the means to train an unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`.
For example:
```python
import torch
from dalle2_pytorch import Unet, Decoder
# unet for the cascading ddpm
unet1 = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

# decoder, which contains the unets

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (256, 512), # first unet up to 256px, then second to 512px
    timesteps = 1000,
    unconditional = True
).cuda()

# mock images (get a lot of this)

images = torch.randn(1, 3, 512, 512).cuda()

# feed images into decoder

for i in (1, 2):
    loss = decoder(images, unet_number = i)
    loss.backward()
# do the above for many many many many steps
# then it will learn to generate images
images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
```
## Dataloaders
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
@@ -895,14 +946,14 @@ dataset = ImageEmbeddingDataset(
)
```
-## Scripts
### Scripts (wip)
-### Using the `train_diffusion_prior.py` script
#### `train_diffusion_prior.py`
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
-### Usage
#### Usage
```bash
$ python train_diffusion_prior.py
@@ -910,58 +961,49 @@ $ python train_diffusion_prior.py
The most significant parameters for the script are as follows:

- `image-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"`
- `text-embed-url`, default = `"https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"`
- `image-embed-dim`, default = `768` - 768 corresponds to the ViT-L/14 embedding size; change it to match what your chosen ViT generates
- `learning-rate`, default = `1.1e-4`
- `weight-decay`, default = `6.02e-2`
- `max-grad-norm`, default = `0.5`
- `batch-size`, default = `10 ** 4`
- `num-epochs`, default = `5`
- `clip`, default = `None` # Signals the prior to use pre-computed embeddings
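As noted above, the script hands `text_embed` and `image_embed` directly to the `DiffusionPrior`. The snippet below is a minimal illustrative sketch of that call pattern, not part of the script itself: the network and prior settings mirror the defaults listed above, and the random tensors stand in for embeddings streamed from the embedding readers.

```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior

# prior network + prior configured for pre-computed embeddings (clip = None)
prior_network = DiffusionPriorNetwork(
    dim = 768,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = None,
    image_embed_dim = 768,
    timesteps = 100,
    cond_drop_prob = 0.1,
    condition_on_text_encodings = False
)

# stand-ins for embeddings loaded from the embedding readers
text_embed = torch.randn(4, 768)
image_embed = torch.randn(4, 768)

loss = diffusion_prior(text_embed = text_embed, image_embed = image_embed)
loss.backward()
```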
-### Sample wandb run log
-Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
-### Loading and saving the Diffusion Prior model
#### Loading and Saving the DiffusionPrior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
-## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```python
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
```
##### Loading
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
##### Saving
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
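As a rough sketch of the save/load round trip described above (not code from the repository): the checkpoint path is a placeholder, the prior is a stripped-down stand-in for one built as in train_diffusion_prior.py, and the config dict is deliberately minimal where the script builds a much fuller one.

```python
import torch
from torch.cuda.amp import GradScaler

from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model

device = torch.device('cuda:0')

# stand-in prior on pre-computed embeddings, as in the sketch further above
prior_network = DiffusionPriorNetwork(dim = 768, depth = 6, dim_head = 64, heads = 8)
diffusion_prior = DiffusionPrior(net = prior_network, clip = None, image_embed_dim = 768, timesteps = 100, condition_on_text_encodings = False).to(device)

optimizer = get_optimizer(diffusion_prior.net.parameters(), wd = 6.02e-2, lr = 1.1e-4)
scaler = GradScaler(enabled = False)

config = {"image_embed_dim": 768}  # placeholder - see train_diffusion_prior.py for the full config dict

# save a checkpoint
save_diffusion_model('./prior.pth', diffusion_prior, optimizer, scaler, config, 768)

# restore it later (see train_diffusion_prior.py for how the return value is used there)
loaded_prior = load_diffusion_model('./prior.pth', device)
```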
## CLI (wip)
@@ -1021,6 +1063,8 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder unet or vqgan-vae training
- [ ] decoder needs one day's worth of refactoring for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that change the network architecture, make it a requirement that they expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored as well (as well as the step number)
## Citations

dalle2_pytorch/dalle2_pytorch.py

@@ -1305,7 +1305,7 @@ class Unet(nn.Module):
    self,
    dim,
    *,
-    image_embed_dim,
    image_embed_dim = None,
    text_embed_dim = None,
    cond_dim = None,
    num_image_tokens = 4,
@@ -1377,7 +1377,7 @@ class Unet(nn.Module):
self.image_to_cond = nn.Sequential(
    nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
    Rearrange('b (n d) -> b n d', n = num_image_tokens)
-) if image_embed_dim != cond_dim else nn.Identity()
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()

self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
@@ -1701,7 +1701,7 @@ class Decoder(BaseGaussianDiffusion):
self.unconditional = unconditional

assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'

self.clip = None
if exists(clip):
@@ -1988,8 +1988,7 @@ class Decoder(BaseGaussianDiffusion):
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)

-if exists(lowres_cond_img):
-    lowres_cond_img = vae.encode(lowres_cond_img)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

img = self.p_sample_loop(
    unet,
@@ -2037,12 +2036,12 @@ class Decoder(BaseGaussianDiffusion):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)

-if not exists(image_embed):
if not exists(image_embed) and not self.unconditional:
    assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
    image_embed, _ = self.clip.embed_image(image)

text_encodings = text_mask = None

-if exists(text) and not exists(text_encodings):
if exists(text) and not exists(text_encodings) and not self.unconditional:
    assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
    _, text_encodings, text_mask = self.clip.embed_text(text)
@@ -2063,9 +2062,7 @@ class Decoder(BaseGaussianDiffusion):
vae.eval()
with torch.no_grad():
    image = vae.encode(image)

-if exists(lowres_cond_img):
-    lowres_cond_img = vae.encode(lowres_cond_img)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)

dalle2_pytorch/train.py

@@ -1,7 +1,7 @@
import time
import copy
from math import ceil
-from functools import partial
from functools import partial, wraps
from collections.abc import Iterable

import torch
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer

import numpy as np

# helper functions

def exists(val):
@@ -45,6 +47,29 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# decorators

def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', next(model.parameters()).device)
        cast_device = kwargs.pop('_cast_device', True)

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner
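A toy illustration of what the decorator above provides (this module is not from the repo, and it assumes `cast_torch_tensor` is importable from `dalle2_pytorch.train`, where the surrounding code appears to live): numpy inputs are converted to torch tensors and moved onto the wrapped model's device before the method body runs.

```python
import numpy as np
import torch
from torch import nn

from dalle2_pytorch.train import cast_torch_tensor  # assumed import path for the decorator above

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(512, 512)

    @cast_torch_tensor
    def forward(self, image_embed):
        # by the time we get here, image_embed is a torch.Tensor on self.net's device
        return self.net(image_embed)

model = ToyModel()
out = model(np.random.randn(4, 512).astype(np.float32))  # numpy in, tensor out
```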
# gradient accumulation functions

def split_iterable(it, split_size):
@@ -80,13 +105,13 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
batch_size = len(first_tensor)
split_size = default(split_size, batch_size)
-chunk_size = ceil(batch_size / split_size)
num_chunks = ceil(batch_size / split_size)

dict_len = len(kwargs)
dict_keys = kwargs.keys()
split_kwargs_index = len_all_args - dict_len

-split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))

for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
@@ -254,10 +279,12 @@ class DiffusionPriorTrainer(nn.Module):
self.step += 1

@torch.inference_mode()
@cast_torch_tensor
def p_sample_loop(self, *args, **kwargs):
    return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)

@torch.inference_mode()
@cast_torch_tensor
def sample(self, *args, **kwargs):
    return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@@ -265,6 +292,7 @@ class DiffusionPriorTrainer(nn.Module):
def sample_batch_size(self, *args, **kwargs):
    return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)

@cast_torch_tensor
def forward(
    self,
    *args,
@@ -279,7 +307,9 @@ class DiffusionPriorTrainer(nn.Module):
loss = loss * chunk_size_frac
total_loss += loss.item()

-self.scaler.scale(loss).backward()
if self.training:
    self.scaler.scale(loss).backward()

return total_loss
@@ -375,6 +405,7 @@ class DecoderTrainer(nn.Module):
self.step += 1

@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
    if self.use_ema:
        trainable_unets = self.decoder.unets
@@ -391,6 +422,7 @@ class DecoderTrainer(nn.Module):
return output

@cast_torch_tensor
def forward(
    self,
    *args,
@@ -406,6 +438,8 @@ class DecoderTrainer(nn.Module):
loss = loss * chunk_size_frac
total_loss += loss.item()

-self.scale(loss, unet_number = unet_number).backward()
if self.training:
    self.scale(loss, unet_number = unet_number).backward()

return total_loss
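A self-contained toy sketch (not repo code) of the train/eval behavior added to both trainers above: the forward pass computes the loss, backpropagates only when the module is in training mode, and simply reports the loss otherwise.

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler

class ToyTrainer(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)
        self.scaler = GradScaler(enabled = False)  # no-op scaler so this runs on cpu

    def forward(self, x, y):
        loss = nn.functional.mse_loss(self.net(x), y)
        if self.training:
            self.scaler.scale(loss).backward()  # accumulate gradients only in training mode
        return loss.item()

trainer = ToyTrainer()

trainer.train()
trainer(torch.randn(4, 8), torch.randn(4, 1))   # gradients are accumulated
assert trainer.net.weight.grad is not None

trainer.eval()
with torch.no_grad():
    val_loss = trainer(torch.randn(4, 8), torch.randn(4, 1))  # loss only, no backward call
```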

setup.py

@@ -10,7 +10,7 @@ setup(
    'dream = dalle2_pytorch.cli:dream'
  ],
},
-version = '0.2.31',
version = '0.2.35',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',

train_diffusion_prior.py

@@ -111,37 +111,110 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity), "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)}) "Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
-def train(image_embed_dim,
-    image_embed_url,
-    text_embed_url,
-    batch_size,
-    train_percent,
-    val_percent,
-    test_percent,
-    num_epochs,
-    dp_loss_type,
-    clip,
-    dp_condition_on_text_encodings,
-    dp_timesteps,
-    dp_normformer,
-    dp_cond_drop_prob,
-    dpn_depth,
-    dpn_dim_head,
-    dpn_heads,
-    save_interval,
-    save_path,
-    device,
-    RESUME,
-    DPRIOR_PATH,
-    config,
-    wandb_entity,
-    wandb_project,
-    learning_rate=0.001,
-    max_grad_norm=0.5,
-    weight_decay=0.01,
-    dropout=0.05,
-    amp=False):
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@click.option("--wandb-dataset", default="LAION-5B")
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
@click.option("--batch-size", default=10**4)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
@click.option("--train-percent", default=0.7)
@click.option("--val-percent", default=0.2)
@click.option("--test-percent", default=0.1)
@click.option("--dpn-depth", default=6)
@click.option("--dpn-dim-head", default=64)
@click.option("--dpn-heads", default=8)
@click.option("--dp-condition-on-text-encodings", default=False)
@click.option("--dp-timesteps", default=100)
@click.option("--dp-normformer", default=False)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
@click.option("--clip", default=None)
@click.option("--amp", default=False)
@click.option("--save-interval", default=30)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
def train(
wandb_entity,
wandb_project,
wandb_dataset,
wandb_arch,
image_embed_url,
text_embed_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
batch_size,
num_epochs,
image_embed_dim,
train_percent,
val_percent,
test_percent,
dpn_depth,
dpn_dim_head,
dpn_heads,
dp_condition_on_text_encodings,
dp_timesteps,
dp_normformer,
dp_cond_drop_prob,
dp_loss_type,
clip,
amp,
save_interval,
save_path,
pretrained_model_path
):
config = {
"learning_rate": learning_rate,
"architecture": wandb_arch,
"dataset": wandb_dataset,
"weight_decay": weight_decay,
"max_gradient_clipping_norm": max_grad_norm,
"batch_size": batch_size,
"epochs": num_epochs,
"diffusion_prior_network": {
"depth": dpn_depth,
"dim_head": dpn_dim_head,
"heads": dpn_heads,
"normformer": dp_normformer
},
"diffusion_prior": {
"condition_on_text_encodings": dp_condition_on_text_encodings,
"timesteps": dp_timesteps,
"cond_drop_prob": dp_cond_drop_prob,
"loss_type": dp_loss_type,
"clip": clip
}
}
# Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = args.pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
tracker.init(
entity = wandb_entity,
project = wandb_project,
config = config
)
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Training loop
# diffusion prior network # diffusion prior network
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
@@ -269,140 +342,5 @@ def train(image_embed_dim,
end = num_data_points
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
-@click.command()
-@click.option("--wandb-entity", default="laion")
-@click.option("--wandb-project", default="diffusion-prior")
-@click.option("--wandb-dataset", default="LAION-5B")
-@click.option("--wandb-arch", default="DiffusionPrior")
-@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
-@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
-@click.option("--learning-rate", default=1.1e-4)
-@click.option("--weight-decay", default=6.02e-2)
-@click.option("--dropout", default=5e-2)
-@click.option("--max-grad-norm", default=0.5)
-@click.option("--batch-size", default=10**4)
-@click.option("--num-epochs", default=5)
-@click.option("--image-embed-dim", default=768)
-@click.option("--train-percent", default=0.7)
-@click.option("--val-percent", default=0.2)
-@click.option("--test-percent", default=0.1)
-@click.option("--dpn-depth", default=6)
-@click.option("--dpn-dim-head", default=64)
-@click.option("--dpn-heads", default=8)
-@click.option("--dp-condition-on-text-encodings", default=False)
-@click.option("--dp-timesteps", default=100)
-@click.option("--dp-normformer", default=False)
-@click.option("--dp-cond-drop-prob", default=0.1)
-@click.option("--dp-loss-type", default="l2")
-@click.option("--clip", default=None)
-@click.option("--amp", default=False)
-@click.option("--save-interval", default=30)
-@click.option("--save-path", default="./diffusion_prior_checkpoints")
-@click.option("--pretrained-model-path", default=None)
-def main(
-    wandb_entity,
-    wandb_project,
-    wandb_dataset,
-    wandb_arch,
-    image_embed_url,
-    text_embed_url,
-    learning_rate,
-    weight_decay,
-    dropout,
-    max_grad_norm,
-    batch_size,
-    num_epochs,
-    image_embed_dim,
-    train_percent,
-    val_percent,
-    test_percent,
-    dpn_depth,
-    dpn_dim_head,
-    dpn_heads,
-    dp_condition_on_text_encodings,
-    dp_timesteps,
-    dp_normformer,
-    dp_cond_drop_prob,
-    dp_loss_type,
-    clip,
-    amp,
-    save_interval,
-    save_path,
-    pretrained_model_path
-):
-    config = {
-        "learning_rate": learning_rate,
-        "architecture": wandb_arch,
-        "dataset": wandb_dataset,
-        "weight_decay": weight_decay,
-        "max_gradient_clipping_norm": max_grad_norm,
-        "batch_size": batch_size,
-        "epochs": num_epochs,
-        "diffusion_prior_network": {
-            "depth": dpn_depth,
-            "dim_head": dpn_dim_head,
-            "heads": dpn_heads,
-            "normformer": dp_normformer
-        },
-        "diffusion_prior": {
-            "condition_on_text_encodings": dp_condition_on_text_encodings,
-            "timesteps": dp_timesteps,
-            "cond_drop_prob": dp_cond_drop_prob,
-            "loss_type": dp_loss_type,
-            "clip": clip
-        }
-    }
-    # Check if DPRIOR_PATH exists(saved model path)
-    DPRIOR_PATH = args.pretrained_model_path
-    RESUME = exists(DPRIOR_PATH)
-    if not RESUME:
-        tracker.init(
-            entity = wandb_entity,
-            project = wandb_project,
-            config = config
-        )
-    # Obtain the utilized device.
-    has_cuda = torch.cuda.is_available()
-    if has_cuda:
-        device = torch.device("cuda:0")
-        torch.cuda.set_device(device)
-    # Training loop
-    train(image_embed_dim,
-        image_embed_url,
-        text_embed_url,
-        batch_size,
-        train_percent,
-        val_percent,
-        test_percent,
-        num_epochs,
-        dp_loss_type,
-        clip,
-        dp_condition_on_text_encodings,
-        dp_timesteps,
-        dp_normformer,
-        dp_cond_drop_prob,
-        dpn_depth,
-        dpn_dim_head,
-        dpn_heads,
-        save_interval,
-        save_path,
-        device,
-        RESUME,
-        DPRIOR_PATH,
-        config,
-        wandb_entity,
-        wandb_project,
-        learning_rate,
-        max_grad_norm,
-        weight_decay,
-        dropout,
-        amp)
if __name__ == "__main__":
-    main()
    train()