mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
extra insurance that diffusion prior is on the correct device, when using trainer with accelerator or device was given
This commit is contained in:
@@ -192,6 +192,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
self.device = diffusion_prior_device
|
self.device = diffusion_prior_device
|
||||||
else:
|
else:
|
||||||
self.device = accelerator.device if exists(accelerator) else device
|
self.device = accelerator.device if exists(accelerator) else device
|
||||||
|
diffusion_prior.to(self.device)
|
||||||
|
|
||||||
# save model
|
# save model
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.16.16'
|
__version__ = '0.16.17'
|
||||||
|
|||||||
Reference in New Issue
Block a user