mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 02:14:26 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
924455d97d |
@@ -105,6 +105,10 @@ class EMA(nn.Module):
|
|||||||
self.register_buffer('initted', torch.Tensor([False]))
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
|
|
||||||
|
def restore_ema_model_device(self):
|
||||||
|
device = self.initted.device
|
||||||
|
self.ema_model.to(device)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
@@ -305,6 +309,11 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.decoder.unets = trainable_unets # restore original training unets
|
self.decoder.unets = trainable_unets # restore original training unets
|
||||||
|
|
||||||
|
# cast the ema_model unets back to original device
|
||||||
|
for ema in self.ema_unets:
|
||||||
|
ema.restore_ema_model_device()
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user