mirror of https://github.com/lucidrains/DALLE2-pytorch.git
align the ema model device back after sampling from the cascading ddpm in the decoder
@@ -105,6 +105,10 @@ class EMA(nn.Module):
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0.]))
 
+    def restore_ema_model_device(self):
+        device = self.initted.device
+        self.ema_model.to(device)
+
     def update(self):
         self.step += 1
 
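Why `self.initted.device` is the right source of truth here: tensors registered with `register_buffer` move together with the parent module under `.to()`, so the `initted` buffer records where the EMA wrapper itself lives even if `ema_model` is moved independently. A minimal sketch of that mechanism (the `Wrapper` and `marker` names are illustrative, not from the repo):

import torch
from torch import nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(4, 4)
        # buffers travel with the parent module, so 'marker' always
        # reflects the wrapper's home device
        self.register_buffer('marker', torch.tensor([0.]))

wrapper = Wrapper()

if torch.cuda.is_available():
    # moving only the child module (as cascading sampling might do,
    # one unet at a time) leaves the buffer behind ...
    wrapper.inner.to('cuda')
    assert wrapper.marker.device.type == 'cpu'
    # ... and the buffer's device tells us where to move the child
    # back to, mirroring restore_ema_model_device above
    wrapper.inner.to(wrapper.marker.device)
    assert next(wrapper.inner.parameters()).device.type == 'cpu'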
@@ -305,6 +309,11 @@ class DecoderTrainer(nn.Module):
 
         if self.use_ema:
             self.decoder.unets = trainable_unets # restore original training unets
+
+            # cast the ema_model unets back to original device
+            for ema in self.ema_unets:
+                ema.restore_ema_model_device()
+
         return output
 
     def forward(
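For context, a self-contained toy of the bug-and-fix pattern this hunk applies. The assumption (suggested by the commit message, not spelled out in the diff) is that sampling through the cascading DDPM moves each EMA unet to the sampling device in turn and can leave them stranded there; `ToyEMA` and `sample_with_restore` are hypothetical names standing in for the repo's classes:

import torch
from torch import nn

class ToyEMA(nn.Module):
    # toy stand-in for the repo's EMA wrapper, keeping only the device logic
    def __init__(self, model):
        super().__init__()
        self.ema_model = model
        self.register_buffer('initted', torch.Tensor([False]))

    def restore_ema_model_device(self):
        # snap the wrapped model back to the wrapper's home device
        self.ema_model.to(self.initted.device)

def sample_with_restore(ema_unets, sample_device):
    # mimic a cascading sampler that shuttles one unet at a time
    # onto the sampling device
    for ema in ema_unets:
        ema.ema_model.to(sample_device)
        # ... run this stage of the cascade ...
    # the fix: cast every EMA unet back afterwards, exactly as the
    # loop added in DecoderTrainer does
    for ema in ema_unets:
        ema.restore_ema_model_device()

unets = [ToyEMA(nn.Linear(8, 8)) for _ in range(3)]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample_with_restore(unets, device)
assert all(next(e.ema_model.parameters()).device == e.initted.device for e in unets)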