mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
some cleanup
This commit is contained in:
@@ -39,6 +39,50 @@ def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
flank = symbol * repeat
|
||||
return f'{flank} {s} {flank}'
|
||||
|
||||
# saving and loading functions
|
||||
|
||||
# for diffusion prior
|
||||
|
||||
def load_diffusion_model(dprior_path, device):
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print_ribbon('Saving checkpoint')
|
||||
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user