Improved upsampler training (#181)

Sampling is now possible without the first decoder unet

Non-training unets are deleted in the decoder trainer since they are never used and it is harder merge the models is they have keys in this state dict

Fixed a mistake where clip was not re-added after saving
This commit is contained in:
Aidan Dempster
2022-07-19 22:07:50 -04:00
committed by GitHub
parent 4b912a38c6
commit 4145474bab
6 changed files with 104 additions and 49 deletions

View File

@@ -530,11 +530,14 @@ class Tracker:
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
prior: DiffusionPrior = trainer.unwrap_model(prior)
# Remove CLIP if it is part of the model
original_clip = prior.clip
prior.clip = None
model_state_dict = prior.state_dict()
prior.clip = original_clip
elif isinstance(trainer, DecoderTrainer):
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# Remove CLIP if it is part of the model
original_clip = decoder.clip
decoder.clip = None
if trainer.use_ema:
trainable_unets = decoder.unets
@@ -543,6 +546,7 @@ class Tracker:
decoder.unets = trainable_unets # Swap back
else:
model_state_dict = decoder.state_dict()
decoder.clip = original_clip
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
state_dict = {