mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
extra insurance that diffusion prior is on the correct device, when using trainer with accelerator or device was given
This commit is contained in:
@@ -192,6 +192,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
self.device = diffusion_prior_device
|
||||
else:
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
diffusion_prior.to(self.device)
|
||||
|
||||
# save model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user