Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-20 02:04:19 +01:00)
cleanup
This commit is contained in:

README.md | 27
@@ -1017,33 +1017,6 @@ The most significant parameters for the script are as follows:
 - `clip`, default = `None` # Signals the prior to use pre-computed embeddings
-
-#### Loading and Saving the DiffusionPrior model
-
-Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
-
-```python
-from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
-```
-
-##### Loading
-
-load_diffusion_model(dprior_path, device)
-dprior_path : path to saved model(.pth)
-device : the cuda device you're running on
-
-##### Saving
-
-save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
-save_path : path to save at
-model : object of Diffusion_Prior
-optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
-e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
-scaler : a GradScaler object.
-e.g: scaler = GradScaler(enabled=amp)
-config : config object created in train_diffusion_prior.py - see file for example.
-image_embed_dim - the dimension of the image_embedding
-e.g: 768
 ## CLI (wip)

 ```bash
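Note: for anyone still depending on the README recipe removed above, the sketch below reconstructs the round trip it described. It assumes a checkout from *before* this commit, where `dalle2_pytorch.train` still exported these helpers; the checkpoint path, learning rate, and weight decay are placeholders, and `torch.optim.AdamW` stands in for the `get_optimizer` call the removed text mentioned.

```python
# Sketch of the workflow the removed README section described.
# Only valid on a pre-cleanup checkout; these helpers no longer exist
# after this commit.
import torch
from torch.cuda.amp import GradScaler
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loading: rebuilds DiffusionPriorNetwork + DiffusionPrior from the saved
# hyperparameters, then restores the weights.
diffusion_prior, loaded_obj = load_diffusion_model('./prior/saved_model.pth', device)

# Saving: bundles model / optimizer / scaler state with the hyperparameter
# config. Reusing the loaded hparams here stands in for the config object
# that train_diffusion_prior.py builds.
optimizer = torch.optim.AdamW(diffusion_prior.net.parameters(), lr = 1.1e-4, weight_decay = 6.02e-2)
scaler = GradScaler(enabled = True)
save_diffusion_model('./prior', diffusion_prior, optimizer, scaler,
                     config = loaded_obj['hparams'], image_embed_dim = 768)
```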
@@ -145,44 +145,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
         chunk_size_frac = chunk_size / batch_size
         yield chunk_size_frac, (chunked_args, chunked_kwargs)
-
-# saving and loading functions
-
-# for diffusion prior
-
-def load_diffusion_model(dprior_path, device):
-    dprior_path = Path(dprior_path)
-    assert dprior_path.exists(), 'Dprior model file does not exist'
-    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
-
-    # Get hyperparameters of loaded model
-    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
-    dp_config = loaded_obj['hparams']['diffusion_prior']
-    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
-
-    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
-
-    # DiffusionPriorNetwork
-    prior_network = DiffusionPriorNetwork(dim = image_embed_dim, **dpn_config).to(device)
-
-    # DiffusionPrior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
-
-    # Load state dict from saved model
-    diffusion_prior.load_state_dict(loaded_obj['model'])
-
-    return diffusion_prior, loaded_obj
-
-def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
-    # Saving State Dict
-    print_ribbon('Saving checkpoint')
-
-    state_dict = dict(model = model.state_dict(),
-                      optimizer = optimizer.state_dict(),
-                      scaler = scaler.state_dict(),
-                      hparams = config,
-                      image_embed_dim = {"image_embed_dim": image_embed_dim})
-    torch.save(state_dict, save_path + '/' + str(time.time()) + '_saved_model.pth')
-
 # exponential moving average wrapper

 class EMA(nn.Module):
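The helpers deleted above pinned down a specific checkpoint layout, including the slightly odd doubly nested `image_embed_dim` key that the loader had to unwrap. The self-contained sketch below reproduces that layout with plain `torch.save`/`torch.load`; the `nn.Linear` and its optimizer are stand-ins for the real DiffusionPrior and the objects built in train_diffusion_prior.py.

```python
import time
import torch
from torch import nn
from torch.cuda.amp import GradScaler

# Stand-ins so the sketch runs on its own; the real objects are the
# DiffusionPrior, its optimizer, and the AMP scaler.
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters())
scaler = GradScaler(enabled = False)
config = {'diffusion_prior_network': {}, 'diffusion_prior': {}}  # shape read back by the loader
image_embed_dim = 768

# Checkpoint layout of the removed save_diffusion_model. Note that
# image_embed_dim is wrapped in a one-key dict, which is why the removed
# loader reads loaded_obj['image_embed_dim']['image_embed_dim'].
ckpt_path = f'./{time.time()}_saved_model.pth'
torch.save(dict(
    model = model.state_dict(),
    optimizer = optimizer.state_dict(),
    scaler = scaler.state_dict(),
    hparams = config,
    image_embed_dim = {'image_embed_dim': image_embed_dim},
), ckpt_path)

# Round trip: load on CPU, then restore the weights.
loaded_obj = torch.load(ckpt_path, map_location = 'cpu')
model.load_state_dict(loaded_obj['model'])
```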
@@ -505,26 +467,20 @@ class DiffusionPriorTrainer(nn.Module):
     @cast_torch_tensor
     @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
-        else:
-            return self.diffusion_prior.p_sample_loop(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.p_sample_loop(*args, **kwargs)

     @torch.no_grad()
     @cast_torch_tensor
     @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
-        else:
-            return self.diffusion_prior.sample(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.sample(*args, **kwargs)

     @torch.no_grad()
     def sample_batch_size(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
-        else:
-            return self.diffusion_prior.sample_batch_size(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.sample_batch_size(*args, **kwargs)

     @torch.no_grad()
     @cast_torch_tensor
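The hunk above applies one refactor three times: each sampling method picks the EMA copy or the online model with a single conditional expression instead of a duplicated if/else. A minimal standalone illustration of the pattern follows; `TrainerSketch` and `Sampler` are hypothetical stand-ins, not the repo's `DiffusionPriorTrainer`.

```python
from torch import nn

class Sampler(nn.Module):
    """Stand-in for DiffusionPrior: anything exposing .sample()."""
    def sample(self, n):
        return [0.0] * n

class TrainerSketch(nn.Module):
    def __init__(self, online_model, ema_model = None):
        super().__init__()
        self.online_model = online_model
        self.ema_model = ema_model
        self.use_ema = ema_model is not None

    @property
    def sampling_model(self):
        # one conditional expression replaces the repeated if/else branches;
        # EMA weights are generally preferred for sampling once training settles
        return self.ema_model if self.use_ema else self.online_model

    def sample(self, *args, **kwargs):
        return self.sampling_model.sample(*args, **kwargs)

trainer = TrainerSketch(Sampler(), ema_model = Sampler())
samples = trainer.sample(4)  # dispatches to the EMA copy
```

Centralizing the choice this way keeps each sampling method to a single line and removes three copies of the same branch.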