mirror of https://github.com/lucidrains/DALLE2-pytorch.git
default the device to the device that the diffusion prior parameters are on, if the trainer was never given the accelerator nor device
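
In practice, the trainer can now be constructed from just the prior. A hedged usage sketch (assuming a built DiffusionPrior instance named prior, and that the remaining constructor arguments have workable defaults):

    trainer = DiffusionPriorTrainer(prior)  # neither accelerator nor device supplied
    print(trainer.device)                   # falls back to next(prior.parameters()).device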
@@ -173,14 +173,26 @@ class DiffusionPriorTrainer(nn.Module):
         super().__init__()
         assert isinstance(diffusion_prior, DiffusionPrior)
         assert not exists(accelerator) or isinstance(accelerator, Accelerator)
-        assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
 
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
 
+        # verbosity
+
+        self.verbose = verbose
+
         # assign some helpful member vars
 
         self.accelerator = accelerator
-        self.device = accelerator.device if exists(accelerator) else device
         self.text_conditioned = diffusion_prior.condition_on_text_encodings
 
+        # setting the device
+
+        if not exists(accelerator) and not exists(device):
+            diffusion_prior_device = next(diffusion_prior.parameters()).device
+            self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
+            self.device = diffusion_prior_device
+        else:
+            self.device = accelerator.device if exists(accelerator) else device
+
         # save model
 
         self.diffusion_prior = diffusion_prior
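
For reference, a minimal standalone sketch of the same fallback order, using a hypothetical helper resolve_device that is not part of the library (the accelerator-first, explicit-device-second, parameter-device-last ordering is the one in the hunk above):

    import torch
    import torch.nn as nn

    def resolve_device(model, accelerator = None, device = None):
        # hypothetical helper mirroring the trainer's fallback order:
        # accelerator device first, then an explicit device, then the
        # device the model's parameters already live on
        if accelerator is not None:
            return accelerator.device
        if device is not None:
            return torch.device(device)
        return next(model.parameters()).device

    model = nn.Linear(4, 4)
    print(resolve_device(model))  # cpu, inferred from the parameters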
@@ -214,13 +226,9 @@ class DiffusionPriorTrainer(nn.Module):
 
         self.max_grad_norm = max_grad_norm
 
-        # verbosity
-
-        self.verbose = verbose
-
         # track steps internally
 
-        self.register_buffer('step', torch.tensor([0]))
+        self.register_buffer('step', torch.tensor([0], device = self.device))
 
         # accelerator wrappers
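
The second hunk also creates the step buffer directly on the resolved device, so the counter no longer starts out on CPU when the prior lives elsewhere. A minimal sketch with a toy module (not from the library) showing the effect of the device argument:

    import torch
    import torch.nn as nn

    class Counter(nn.Module):
        def __init__(self, device):
            super().__init__()
            # allocating the buffer on the target device up front avoids a
            # CPU-resident tensor that would only move on a later .to(device)
            self.register_buffer('step', torch.tensor([0], device = device))

    print(Counter('cpu').step.device)  # cpu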
@@ -1 +1 @@
-__version__ = '0.16.10'
+__version__ = '0.16.12'