From 64f7be192697c1e790915bf3d17fffef2013d55d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 9 May 2022 16:50:21 -0700 Subject: [PATCH] some cleanup --- README.md | 2 +- dalle2_pytorch/__init__.py | 2 +- dalle2_pytorch/dalle2_pytorch.py | 36 -------------------------- dalle2_pytorch/train.py | 44 ++++++++++++++++++++++++++++++++ train_diffusion_prior.py | 5 ++-- 5 files changed, 49 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index b9c65c3..e5148fa 100644 --- a/README.md +++ b/README.md @@ -933,7 +933,7 @@ Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/r Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory. -## from dalle2_pytorch import load_diffusion_model, save_diffusion_model +## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model load_diffusion_model(dprior_path, device) diff --git a/dalle2_pytorch/__init__.py b/dalle2_pytorch/__init__.py index 96eebca..60987bd 100644 --- a/dalle2_pytorch/__init__.py +++ b/dalle2_pytorch/__init__.py @@ -1,4 +1,4 @@ -from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder,load_diffusion_model,save_diffusion_model +from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index fcd35ec..ffcc621 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -34,42 +34,6 @@ from rotary_embedding_torch import RotaryEmbedding from x_clip import CLIP from coca_pytorch import CoCa -# Diffusion Prior model loading and saving functions - -def load_diffusion_model(dprior_path, device ): - - dprior_path = Path(dprior_path) - assert dprior_path.exists(), 'Dprior model file does not 
exist' - loaded_obj = torch.load(str(dprior_path), map_location='cpu') - - # Get hyperparameters of loaded model - dpn_config = loaded_obj['hparams']['diffusion_prior_network'] - dp_config = loaded_obj['hparams']['diffusion_prior'] - image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim'] - - # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters - - # DiffusionPriorNetwork - prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device) - - # DiffusionPrior with text embeddings and image embeddings pre-computed - diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device) - - # Load state dict from saved model - diffusion_prior.load_state_dict(loaded_obj['model']) - - return diffusion_prior - -def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim): - # Saving State Dict - print("====================================== Saving checkpoint ======================================") - state_dict = dict(model=model.state_dict(), - optimizer=optimizer.state_dict(), - scaler=scaler.state_dict(), - hparams = config, - image_embed_dim = {"image_embed_dim":image_embed_dim}) - torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth') - # helper functions def exists(val): diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index 56f6cae..d086143 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -39,6 +39,50 @@ def groupby_prefix_and_trim(prefix, d): kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) return kwargs_without_prefix, kwargs +# print helpers + +def print_ribbon(s, symbol = '=', repeat = 40): + flank = symbol * repeat + return f'{flank} {s} {flank}' + +# saving and loading functions + +# for diffusion prior + +def load_diffusion_model(dprior_path, device): + dprior_path = Path(dprior_path) + assert dprior_path.exists(), 'Dprior 
model file does not exist'
+    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
+
+    # Get hyperparameters of loaded model
+    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
+    dp_config = loaded_obj['hparams']['diffusion_prior']
+    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
+
+    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
+
+    # DiffusionPriorNetwork
+    prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
+
+    # DiffusionPrior with text embeddings and image embeddings pre-computed
+    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
+
+    # Load state dict from saved model
+    diffusion_prior.load_state_dict(loaded_obj['model'])
+
+    return diffusion_prior
+
+def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
+    # Saving State Dict
+    print(print_ribbon('Saving checkpoint'))
+
+    state_dict = dict(model=model.state_dict(),
+                      optimizer=optimizer.state_dict(),
+                      scaler=scaler.state_dict(),
+                      hparams = config,
+                      image_embed_dim = {"image_embed_dim":image_embed_dim})
+    torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
+
 # exponential moving average wrapper
 
 class EMA(nn.Module):
diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index f2f8ad2..c513c12 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -6,7 +6,8 @@ import numpy as np
 import torch
 from torch import nn
 from embedding_reader import EmbeddingReader
-from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, save_diffusion_model
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
+from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
 from dalle2_pytorch.optimizer import get_optimizer
 from torch.cuda.amp import autocast,GradScaler
 
@@ -153,7 +154,7 @@ def train(image_embed_dim,
os.makedirs(save_path)
 
     # Get image and text embeddings from the servers
-    print("==============Downloading embeddings - image and text====================")
+    print(print_ribbon("Downloading embeddings - image and text"))
     image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
     text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
     num_data_points = text_reader.count