mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 01:34:19 +01:00
Re-introduced change that was accidentally rolled back (#212)
This commit is contained in:
@@ -528,8 +528,12 @@ class Tracker:
|
||||
elif save_type == 'model':
|
||||
if isinstance(trainer, DiffusionPriorTrainer):
|
||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||
state_dict = trainer.accelerator.unwrap_model(prior).state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
|
||||
# Remove CLIP if it is part of the model
|
||||
original_clip = prior.clip
|
||||
prior.clip = None
|
||||
model_state_dict = prior.state_dict()
|
||||
prior.clip = original_clip
|
||||
elif isinstance(trainer, DecoderTrainer):
|
||||
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
# Remove CLIP if it is part of the model
|
||||
|
||||
Reference in New Issue
Block a user