From 56883910fbc87304cd79f80bfa252618f451324b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 20 Jun 2022 11:14:50 -0700
Subject: [PATCH] cleanup

---
 README.md                 | 27 -------------------
 dalle2_pytorch/trainer.py | 56 +++++----------------------------------
 2 files changed, 6 insertions(+), 77 deletions(-)

diff --git a/README.md b/README.md
index 585fbd5..9e63ee0 100644
--- a/README.md
+++ b/README.md
@@ -1017,33 +1017,6 @@ The most significant parameters for the script are as follows:
 
 - `clip`, default = `None` # Signals the prior to use pre-computed embeddings
 
-#### Loading and Saving the DiffusionPrior model
-
-Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
-
-```python
-from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
-```
-
-##### Loading
-
-    load_diffusion_model(dprior_path, device)
-    dprior_path : path to saved model(.pth)
-    device : the cuda device you're running on
-
-##### Saving
-
-    save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
-    save_path : path to save at
-    model : object of Diffusion_Prior
-    optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
-        e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
-    scaler : a GradScaler object.
-        e.g: scaler = GradScaler(enabled=amp)
-    config : config object created in train_diffusion_prior.py - see file for example.
-    image_embed_dim - the dimension of the image_embedding
-        e.g: 768
-
 ## CLI (wip)
 
 ```bash
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 83aa3d8..9a5bad8 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -145,44 +145,6 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
         chunk_size_frac = chunk_size / batch_size
         yield chunk_size_frac, (chunked_args, chunked_kwargs)
 
-# saving and loading functions
-
-# for diffusion prior
-
-def load_diffusion_model(dprior_path, device):
-    dprior_path = Path(dprior_path)
-    assert dprior_path.exists(), 'Dprior model file does not exist'
-    loaded_obj = torch.load(str(dprior_path), map_location='cpu')
-
-    # Get hyperparameters of loaded model
-    dpn_config = loaded_obj['hparams']['diffusion_prior_network']
-    dp_config = loaded_obj['hparams']['diffusion_prior']
-    image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
-
-    # Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
-
-    # DiffusionPriorNetwork
-    prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
-
-    # DiffusionPrior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
-
-    # Load state dict from saved model
-    diffusion_prior.load_state_dict(loaded_obj['model'])
-
-    return diffusion_prior, loaded_obj
-
-def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
-    # Saving State Dict
-    print_ribbon('Saving checkpoint')
-
-    state_dict = dict(model=model.state_dict(),
-                      optimizer=optimizer.state_dict(),
-                      scaler=scaler.state_dict(),
-                      hparams = config,
-                      image_embed_dim = {"image_embed_dim":image_embed_dim})
-    torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
-
 # exponential moving average wrapper
 
 class EMA(nn.Module):
@@ -505,26 +467,20 @@ class DiffusionPriorTrainer(nn.Module):
     @cast_torch_tensor
     @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
-        else:
-            return self.diffusion_prior.p_sample_loop(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.p_sample_loop(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
     @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
-        else:
-            return self.diffusion_prior.sample(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.sample(*args, **kwargs)
 
     @torch.no_grad()
     def sample_batch_size(self, *args, **kwargs):
-        if self.use_ema:
-            return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
-        else:
-            return self.diffusion_prior.sample_batch_size(*args, **kwargs)
+        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+        return model.sample_batch_size(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
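
Note on the trainer.py hunk above: the three sampling methods now share one dispatch pattern, picking the EMA weights when `use_ema` is set instead of branching per method. A minimal self-contained sketch of that pattern follows; `TinyPrior`, the bare-bones `EMA` holder, and `Trainer` are illustrative stand-ins, not the repo's classes — only the attribute names (`use_ema`, `ema_diffusion_prior.ema_model`, `diffusion_prior`) mirror the diff.

```python
import copy
import torch
from torch import nn

# Stand-in for the diffusion prior: any module exposing a sample() method.
class TinyPrior(nn.Module):
    def __init__(self, dim = 8):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    @torch.no_grad()
    def sample(self, x):
        return self.net(x)

# Bare-bones EMA holder: keeps a decoupled copy that would normally
# track the online weights with an exponential moving average.
class EMA(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.ema_model = copy.deepcopy(model)

class Trainer(nn.Module):
    def __init__(self, prior, use_ema = True):
        super().__init__()
        self.diffusion_prior = prior
        self.use_ema = use_ema
        self.ema_diffusion_prior = EMA(prior) if use_ema else None

    @torch.no_grad()
    def sample(self, *args, **kwargs):
        # single dispatch point, exactly as in the patched methods
        model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
        return model.sample(*args, **kwargs)

trainer = Trainer(TinyPrior())
print(trainer.sample(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```

Sampling from the EMA copy at inference time is the usual choice, since the averaged parameters are typically smoother than the raw online weights.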
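For readers tracking the removed `save_diffusion_model` / `load_diffusion_model` helpers: the deleted code bundled the model, optimizer, and GradScaler state dicts plus hyperparameters into a single `torch.save` checkpoint. Below is a condensed sketch of that round-trip, reconstructed from the deleted lines, with a toy `nn.Linear` standing in for the prior and the hparams dict left empty for brevity.

```python
import time
import torch
from torch import nn
from torch.cuda.amp import GradScaler

# Toy stand-ins; only the checkpoint dict layout mirrors the deleted helpers.
model = nn.Linear(768, 768)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
scaler = GradScaler(enabled = False)

# save: same keys as the removed save_diffusion_model
state_dict = dict(
    model = model.state_dict(),
    optimizer = optimizer.state_dict(),
    scaler = scaler.state_dict(),
    hparams = {'diffusion_prior': {}, 'diffusion_prior_network': {}},
    image_embed_dim = {'image_embed_dim': 768}
)
save_path = f'./{time.time()}_saved_model.pth'
torch.save(state_dict, save_path)

# load: mirrors load_diffusion_model - read on CPU, then restore weights
loaded_obj = torch.load(save_path, map_location = 'cpu')
model.load_state_dict(loaded_obj['model'])
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']  # 768
```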