Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
align the ema model device back after sampling from the cascading ddpm in the decoder
@@ -105,6 +105,10 @@ class EMA(nn.Module):
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0.]))
 
+    def restore_ema_model_device(self):
+        device = self.initted.device
+        self.ema_model.to(device)
+
     def update(self):
         self.step += 1
 
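For readers skimming the hunk above: the helper leans on the fact that registered buffers travel with their module when it is moved, so `initted` always records the device the EMA wrapper currently lives on. Below is a minimal sketch of how the new method sits inside the class; everything outside the diffed lines (the constructor argument, the deep copy, the body of `update`) is an assumption for illustration, not part of this commit.

import copy
import torch
import torch.nn as nn

class EMA(nn.Module):
    def __init__(self, model):                               # hypothetical signature
        super().__init__()
        self.ema_model = copy.deepcopy(model)                # assumed: shadow copy holding the averaged weights
        self.register_buffer('initted', torch.Tensor([False]))
        self.register_buffer('step', torch.tensor([0.]))

    def restore_ema_model_device(self):
        # buffers move with the module, so `initted` marks the wrapper's
        # current device; send the (possibly offloaded) shadow model back to it
        device = self.initted.device
        self.ema_model.to(device)

    def update(self):
        self.step += 1
        # ... exponential moving average of the online weights elided ...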
@@ -305,6 +309,11 @@ class DecoderTrainer(nn.Module):
 
         if self.use_ema:
             self.decoder.unets = trainable_unets # restore original training unets
 
+            # cast the ema_model unets back to original device
+
+            for ema in self.ema_unets:
+                ema.restore_ema_model_device()
+
         return output
 
     def forward(
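And the call site, sketched under assumptions: during sampling the trainer swaps the EMA unets into the decoder, and the cascading DDPM may shuttle each unet between CPU and GPU one stage at a time, which can leave the EMA copies stranded on whatever device sampling finished on. Only the lines shown in the hunk are taken from the commit; the method name, its signature, and the `self.decoder.sample(...)` call are hypothetical.

@torch.no_grad()
def sample(self, *args, **kwargs):                           # hypothetical wrapper inside DecoderTrainer
    trainable_unets = self.decoder.unets                     # assumed: stash the training unets
    if self.use_ema:
        self.decoder.unets = self.ema_unets                  # assumed: sample with the averaged weights

    output = self.decoder.sample(*args, **kwargs)            # cascading ddpm may move unets across devices

    if self.use_ema:
        self.decoder.unets = trainable_unets # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

    return output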