diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index 86b73e1..325ee3f 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -105,6 +105,10 @@ class EMA(nn.Module): self.register_buffer('initted', torch.Tensor([False])) self.register_buffer('step', torch.tensor([0.])) + def restore_ema_model_device(self): + device = self.initted.device + self.ema_model.to(device) + def update(self): self.step += 1 @@ -305,6 +309,11 @@ class DecoderTrainer(nn.Module): if self.use_ema: self.decoder.unets = trainable_unets # restore original training unets + + # cast the ema_model unets back to original device + for ema in self.ema_unets: + ema.restore_ema_model_device() + return output def forward( diff --git a/setup.py b/setup.py index 7c667b8..51d842e 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.11', + version = '0.2.12', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',