mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
Re-introduced change that was accidentally rolled back (#212)
This commit is contained in:
@@ -528,8 +528,12 @@ class Tracker:
|
|||||||
elif save_type == 'model':
|
elif save_type == 'model':
|
||||||
if isinstance(trainer, DiffusionPriorTrainer):
|
if isinstance(trainer, DiffusionPriorTrainer):
|
||||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||||
state_dict = trainer.accelerator.unwrap_model(prior).state_dict()
|
prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
|
||||||
torch.save(state_dict, file_path)
|
# Remove CLIP if it is part of the model
|
||||||
|
original_clip = prior.clip
|
||||||
|
prior.clip = None
|
||||||
|
model_state_dict = prior.state_dict()
|
||||||
|
prior.clip = original_clip
|
||||||
elif isinstance(trainer, DecoderTrainer):
|
elif isinstance(trainer, DecoderTrainer):
|
||||||
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||||
# Remove CLIP if it is part of the model
|
# Remove CLIP if it is part of the model
|
||||||
|
|||||||
Reference in New Issue
Block a user