mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 02:44:26 +01:00

Compare commits

3 commits:

9b322ea634
ba64ea45cc
64f7be1926
README.md

@@ -933,7 +933,7 @@ Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/r
 Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
 
-## from dalle2_pytorch import load_diffusion_model, save_diffusion_model
+## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
 
 load_diffusion_model(dprior_path, device)
 
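For context, a minimal usage sketch of the updated import path; the checkpoint filename below is a placeholder, not a file from this repository:

import torch
from dalle2_pytorch.train import load_diffusion_model

# hypothetical path to a checkpoint previously written by save_diffusion_model
dprior_path = './checkpoints/1650000000.0_saved_model.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# rebuilds the DiffusionPrior from the hparams stored inside the checkpoint
diffusion_prior = load_diffusion_model(dprior_path, device)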
dalle2_pytorch/__init__.py

@@ -1,4 +1,4 @@
-from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder,load_diffusion_model,save_diffusion_model
+from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
 from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
 from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
 
dalle2_pytorch/dalle2_pytorch.py

@@ -5,7 +5,6 @@ from functools import partial
 from contextlib import contextmanager
 from collections import namedtuple
 from pathlib import Path
-import time
 
 import torch
 import torch.nn.functional as F
@@ -34,42 +33,6 @@ from rotary_embedding_torch import RotaryEmbedding
 from x_clip import CLIP
 from coca_pytorch import CoCa
 
-# Diffusion Prior model loading and saving functions
-
-def load_diffusion_model(dprior_path, device):
-
-    dprior_path = Path(dprior_path)
-    assert dprior_path.exists(), 'Dprior model file does not exist'
-    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
-
-    # Get hyperparameters of loaded model
-    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
-    dp_config = loaded_obj['hparams']['diffusion_prior']
-    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
-
-    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
-
-    # DiffusionPriorNetwork
-    prior_network = DiffusionPriorNetwork(dim = image_embed_dim, **dpn_config).to(device)
-
-    # DiffusionPrior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
-
-    # Load state dict from saved model
-    diffusion_prior.load_state_dict(loaded_obj['model'])
-
-    return diffusion_prior
-
-def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
-    # Saving State Dict
-    print("====================================== Saving checkpoint ======================================")
-    state_dict = dict(model=model.state_dict(),
-                      optimizer=optimizer.state_dict(),
-                      scaler=scaler.state_dict(),
-                      hparams = config,
-                      image_embed_dim = {"image_embed_dim":image_embed_dim})
-    torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
-
 # helper functions
 
 def exists(val):
dalle2_pytorch/train.py

@@ -1,3 +1,4 @@
+import time
 import copy
 from functools import partial
 
@@ -39,6 +40,50 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs
 
+# print helpers
+
+def print_ribbon(s, symbol = '=', repeat = 40):
+    flank = symbol * repeat
+    return f'{flank} {s} {flank}'
+
+# saving and loading functions
+
+# for diffusion prior
+
+def load_diffusion_model(dprior_path, device):
+    dprior_path = Path(dprior_path)
+    assert dprior_path.exists(), 'Dprior model file does not exist'
+    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
+
+    # Get hyperparameters of loaded model
+    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
+    dp_config = loaded_obj['hparams']['diffusion_prior']
+    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
+
+    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
+
+    # DiffusionPriorNetwork
+    prior_network = DiffusionPriorNetwork(dim = image_embed_dim, **dpn_config).to(device)
+
+    # DiffusionPrior with text embeddings and image embeddings pre-computed
+    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
+
+    # Load state dict from saved model
+    diffusion_prior.load_state_dict(loaded_obj['model'])
+
+    return diffusion_prior
+
+def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
+    # Saving State Dict
+    print_ribbon('Saving checkpoint')
+
+    state_dict = dict(model=model.state_dict(),
+                      optimizer=optimizer.state_dict(),
+                      scaler=scaler.state_dict(),
+                      hparams = config,
+                      image_embed_dim = {"image_embed_dim":image_embed_dim})
+    torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
+
 # exponential moving average wrapper
 
 class EMA(nn.Module):
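The checkpoint layout these two helpers agree on can be inspected directly; a short sketch, assuming a checkpoint file already written by save_diffusion_model (the path is hypothetical):

import torch

# load a checkpoint produced by save_diffusion_model
ckpt = torch.load('./checkpoints/1650000000.0_saved_model.pth', map_location = 'cpu')

print(ckpt.keys())                                 # dict_keys(['model', 'optimizer', 'scaler', 'hparams', 'image_embed_dim'])
print(ckpt['hparams']['diffusion_prior_network'])  # kwargs load_diffusion_model uses to rebuild DiffusionPriorNetwork
print(ckpt['hparams']['diffusion_prior'])          # kwargs it uses to rebuild DiffusionPrior
print(ckpt['image_embed_dim']['image_embed_dim'])  # note the nested dict, matching what save_diffusion_model writes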
setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.2',
+  version = '0.2.4',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
train_diffusion_prior.py

@@ -6,7 +6,8 @@ import numpy as np
 import torch
 from torch import nn
 from embedding_reader import EmbeddingReader
-from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, save_diffusion_model
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
+from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
 from dalle2_pytorch.optimizer import get_optimizer
 from torch.cuda.amp import autocast,GradScaler
 
@@ -153,7 +154,7 @@ def train(image_embed_dim,
     os.makedirs(save_path)
 
     # Get image and text embeddings from the servers
-    print("==============Downloading embeddings - image and text====================")
+    print_ribbon("Downloading embeddings - image and text")
     image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
     text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
     num_data_points = text_reader.count
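One detail worth noting: as defined in this change, print_ribbon returns the banner string rather than printing it, so a bare call like the one above discards its result. A minimal sketch of the helper and its printing form (the print wrapper is illustrative, not part of this diff):

# the helper as added to dalle2_pytorch/train.py
def print_ribbon(s, symbol = '=', repeat = 40):
    flank = symbol * repeat
    return f'{flank} {s} {flank}'

# wrapping the returned string in print emits the banner:
# the message flanked by 40 repetitions of '=' on each side
print(print_ribbon('Downloading embeddings - image and text'))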