align the ema model device back after sampling from the cascading ddpm in the decoder

This commit is contained in:
Phil Wang
2022-05-11 19:56:54 -07:00
parent 6021945fc8
commit 924455d97d
2 changed files with 10 additions and 1 deletions

View File

@@ -105,6 +105,10 @@ class EMA(nn.Module):
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
    """Move the EMA model back onto whatever device this module lives on.

    Used after sampling temporarily relocated the EMA unets; the
    ``initted`` registered buffer tracks the module's current device,
    so it serves as the source of truth for where to move back to.
    """
    self.ema_model.to(self.initted.device)
def update(self):
self.step += 1
@@ -305,6 +309,11 @@ class DecoderTrainer(nn.Module):
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output
def forward(

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.11',
version = '0.2.12',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',