From 924455d97d0e0230ba6a34c5d3af792d272af481 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 11 May 2022 19:56:54 -0700 Subject: [PATCH] align the ema model device back after sampling from the cascading ddpm in the decoder --- dalle2_pytorch/train.py | 9 +++++++++ setup.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index 86b73e1..325ee3f 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -105,6 +105,10 @@ class EMA(nn.Module): self.register_buffer('initted', torch.Tensor([False])) self.register_buffer('step', torch.tensor([0.])) + def restore_ema_model_device(self): + device = self.initted.device + self.ema_model.to(device) + def update(self): self.step += 1 @@ -305,6 +309,11 @@ class DecoderTrainer(nn.Module): if self.use_ema: self.decoder.unets = trainable_unets # restore original training unets + + # cast the ema_model unets back to original device + for ema in self.ema_unets: + ema.restore_ema_model_device() + return output def forward( diff --git a/setup.py b/setup.py index 7c667b8..51d842e 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.11', + version = '0.2.12', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',