diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index 77575d1..b8d3b2e 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -111,37 +111,111 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
         "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
         "Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
 
-def train(image_embed_dim,
-          image_embed_url,
-          text_embed_url,
-          batch_size,
-          train_percent,
-          val_percent,
-          test_percent,
-          num_epochs,
-          dp_loss_type,
-          clip,
-          dp_condition_on_text_encodings,
-          dp_timesteps,
-          dp_normformer,
-          dp_cond_drop_prob,
-          dpn_depth,
-          dpn_dim_head,
-          dpn_heads,
-          save_interval,
-          save_path,
-          device,
-          RESUME,
-          DPRIOR_PATH,
-          config,
-          wandb_entity,
-          wandb_project,
-          learning_rate=0.001,
-          max_grad_norm=0.5,
-          weight_decay=0.01,
-          dropout=0.05,
-          amp=False):
+@click.command()
+@click.option("--wandb-entity", default="laion")
+@click.option("--wandb-project", default="diffusion-prior")
+@click.option("--wandb-dataset", default="LAION-5B")
+@click.option("--wandb-arch", default="DiffusionPrior")
+@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
+@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
+@click.option("--learning-rate", default=1.1e-4)
+@click.option("--weight-decay", default=6.02e-2)
+@click.option("--dropout", default=5e-2)
+@click.option("--max-grad-norm", default=0.5)
+@click.option("--batch-size", default=10**4)
+@click.option("--num-epochs", default=5)
+@click.option("--image-embed-dim", default=768)
+@click.option("--train-percent", default=0.7)
+@click.option("--val-percent", default=0.2)
+@click.option("--test-percent", default=0.1)
+@click.option("--dpn-depth", default=6)
+@click.option("--dpn-dim-head", default=64)
+@click.option("--dpn-heads", default=8)
+@click.option("--dp-condition-on-text-encodings", default=False)
+@click.option("--dp-timesteps", default=100)
+@click.option("--dp-normformer", default=False)
+@click.option("--dp-cond-drop-prob", default=0.1)
+@click.option("--dp-loss-type", default="l2")
+@click.option("--clip", default=None)
+@click.option("--amp", default=False)
+@click.option("--save-interval", default=30)
+@click.option("--save-path", default="./diffusion_prior_checkpoints")
+@click.option("--pretrained-model-path", default=None)
+def train(
+    wandb_entity,
+    wandb_project,
+    wandb_dataset,
+    wandb_arch,
+    image_embed_url,
+    text_embed_url,
+    learning_rate,
+    weight_decay,
+    dropout,
+    max_grad_norm,
+    batch_size,
+    num_epochs,
+    image_embed_dim,
+    train_percent,
+    val_percent,
+    test_percent,
+    dpn_depth,
+    dpn_dim_head,
+    dpn_heads,
+    dp_condition_on_text_encodings,
+    dp_timesteps,
+    dp_normformer,
+    dp_cond_drop_prob,
+    dp_loss_type,
+    clip,
+    amp,
+    save_interval,
+    save_path,
+    pretrained_model_path
+):
+    """CLI entry point: build the run config, set up tracking and the device, then train the diffusion prior."""
+    config = {
+        "learning_rate": learning_rate,
+        "architecture": wandb_arch,
+        "dataset": wandb_dataset,
+        "weight_decay": weight_decay,
+        "max_gradient_clipping_norm": max_grad_norm,
+        "batch_size": batch_size,
+        "epochs": num_epochs,
+        "diffusion_prior_network": {
+            "depth": dpn_depth,
+            "dim_head": dpn_dim_head,
+            "heads": dpn_heads,
+            "normformer": dp_normformer
+        },
+        "diffusion_prior": {
+            "condition_on_text_encodings": dp_condition_on_text_encodings,
+            "timesteps": dp_timesteps,
+            "cond_drop_prob": dp_cond_drop_prob,
+            "loss_type": dp_loss_type,
+            "clip": clip
+        }
+    }
 
+    # DPRIOR_PATH is the saved-model path from --pretrained-model-path; RESUME follows exists()
+
+    DPRIOR_PATH = pretrained_model_path
+    RESUME = exists(DPRIOR_PATH)
+
+    if not RESUME:
+        tracker.init(
+            entity = wandb_entity,
+            project = wandb_project,
+            config = config
+        )
+
+    # Obtain the utilized device: first CUDA GPU when available, otherwise CPU so `device` is always bound.
+
+    has_cuda = torch.cuda.is_available()
+    device = torch.device("cuda:0" if has_cuda else "cpu")
+    if has_cuda:
+        torch.cuda.set_device(device)
+
+    # Training loop
     # diffusion prior network
 
     prior_network = DiffusionPriorNetwork(
@@ -269,140 +343,5 @@ def train(image_embed_dim,
     end = num_data_points
     eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
 
-@click.command()
-@click.option("--wandb-entity", default="laion")
-@click.option("--wandb-project", default="diffusion-prior")
-@click.option("--wandb-dataset", default="LAION-5B")
-@click.option("--wandb-arch", default="DiffusionPrior")
-@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
-@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
-@click.option("--learning-rate", default=1.1e-4)
-@click.option("--weight-decay", default=6.02e-2)
-@click.option("--dropout", default=5e-2)
-@click.option("--max-grad-norm", default=0.5)
-@click.option("--batch-size", default=10**4)
-@click.option("--num-epochs", default=5)
-@click.option("--image-embed-dim", default=768)
-@click.option("--train-percent", default=0.7)
-@click.option("--val-percent", default=0.2)
-@click.option("--test-percent", default=0.1)
-@click.option("--dpn-depth", default=6)
-@click.option("--dpn-dim-head", default=64)
-@click.option("--dpn-heads", default=8)
-@click.option("--dp-condition-on-text-encodings", default=False)
-@click.option("--dp-timesteps", default=100)
-@click.option("--dp-normformer", default=False)
-@click.option("--dp-cond-drop-prob", default=0.1)
-@click.option("--dp-loss-type", default="l2")
-@click.option("--clip", default=None)
-@click.option("--amp", default=False)
-@click.option("--save-interval", default=30)
-@click.option("--save-path", default="./diffusion_prior_checkpoints")
-@click.option("--pretrained-model-path", default=None)
-def main(
-    wandb_entity,
-    wandb_project,
-    wandb_dataset,
-    wandb_arch,
-    image_embed_url,
-    text_embed_url,
-    learning_rate,
-    weight_decay,
-    dropout,
-    max_grad_norm,
-    batch_size,
-    num_epochs,
-    image_embed_dim,
-    train_percent,
-    val_percent,
-    test_percent,
-    dpn_depth,
-    dpn_dim_head,
-    dpn_heads,
-    dp_condition_on_text_encodings,
-    dp_timesteps,
-    dp_normformer,
-    dp_cond_drop_prob,
-    dp_loss_type,
-    clip,
-    amp,
-    save_interval,
-    save_path,
-    pretrained_model_path
-):
-    config = {
-        "learning_rate": learning_rate,
-        "architecture": wandb_arch,
-        "dataset": wandb_dataset,
-        "weight_decay": weight_decay,
-        "max_gradient_clipping_norm": max_grad_norm,
-        "batch_size": batch_size,
-        "epochs": num_epochs,
-        "diffusion_prior_network": {
-            "depth": dpn_depth,
-            "dim_head": dpn_dim_head,
-            "heads": dpn_heads,
-            "normformer": dp_normformer
-        },
-        "diffusion_prior": {
-            "condition_on_text_encodings": dp_condition_on_text_encodings,
-            "timesteps": dp_timesteps,
-            "cond_drop_prob": dp_cond_drop_prob,
-            "loss_type": dp_loss_type,
-            "clip": clip
-        }
-    }
-
-    # Check if DPRIOR_PATH exists(saved model path)
-
-    DPRIOR_PATH = args.pretrained_model_path
-    RESUME = exists(DPRIOR_PATH)
-
-    if not RESUME:
-        tracker.init(
-            entity = wandb_entity,
-            project = wandb_project,
-            config = config
-        )
-
-    # Obtain the utilized device.
-
-    has_cuda = torch.cuda.is_available()
-    if has_cuda:
-        device = torch.device("cuda:0")
-        torch.cuda.set_device(device)
-
-    # Training loop
-    train(image_embed_dim,
-          image_embed_url,
-          text_embed_url,
-          batch_size,
-          train_percent,
-          val_percent,
-          test_percent,
-          num_epochs,
-          dp_loss_type,
-          clip,
-          dp_condition_on_text_encodings,
-          dp_timesteps,
-          dp_normformer,
-          dp_cond_drop_prob,
-          dpn_depth,
-          dpn_dim_head,
-          dpn_heads,
-          save_interval,
-          save_path,
-          device,
-          RESUME,
-          DPRIOR_PATH,
-          config,
-          wandb_entity,
-          wandb_project,
-          learning_rate,
-          max_grad_norm,
-          weight_decay,
-          dropout,
-          amp)
-
 if __name__ == "__main__":
-    main()
+    train()