diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 2b5e2d3..fa6ed8a 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -192,6 +192,7 @@ class DiffusionPriorTrainer(nn.Module): self.device = diffusion_prior_device else: self.device = accelerator.device if exists(accelerator) else device + diffusion_prior.to(self.device) # save model diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index b086370..ada3c31 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.16' +__version__ = '0.16.17'