mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Val loss changes, with quite a few other changes. This is in place of the earlier PR(https://github.com/lucidrains/DALLE2-pytorch/pull/67) (#77)
* Val_loss changes - no rebased with lucidrains' master. * Val Loss changes - now rebased with lucidrains' master * train_diffusion_prior.py updates * dalle2_pytorch.py updates * __init__.py changes * Update train_diffusion_prior.py * Update dalle2_pytorch.py * Update train_diffusion_prior.py * Update train_diffusion_prior.py * Update dalle2_pytorch.py * Update train_diffusion_prior.py * Update train_diffusion_prior.py * Update train_diffusion_prior.py * Update train_diffusion_prior.py * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md
This commit is contained in:
@@ -4,6 +4,8 @@ from inspect import isfunction
|
||||
from functools import partial
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -32,6 +34,42 @@ from rotary_embedding_torch import RotaryEmbedding
|
||||
from x_clip import CLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
# Diffusion Prior model loading and saving functions
|
||||
|
||||
def load_diffusion_model(dprior_path, device ):
|
||||
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print("====================================== Saving checkpoint ======================================")
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
@@ -1914,3 +1952,4 @@ class DALLE2(nn.Module):
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
|
||||
Reference in New Issue
Block a user