diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py index 6f86ede..a83f6d4 100644 --- a/dalle2_pytorch/trackers.py +++ b/dalle2_pytorch/trackers.py @@ -528,8 +528,12 @@ class Tracker: elif save_type == 'model': if isinstance(trainer, DiffusionPriorTrainer): prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior - state_dict = trainer.accelerator.unwrap_model(prior).state_dict() - torch.save(state_dict, file_path) + prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior) + # Remove CLIP if it is part of the model + original_clip = prior.clip + prior.clip = None + model_state_dict = prior.state_dict() + prior.clip = original_clip elif isinstance(trainer, DecoderTrainer): decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder) # Remove CLIP if it is part of the model