diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 3b14ab7..f288b8e 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -173,14 +173,26 @@ class DiffusionPriorTrainer(nn.Module): super().__init__() assert isinstance(diffusion_prior, DiffusionPrior) assert not exists(accelerator) or isinstance(accelerator, Accelerator) - assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device." ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) + # verbosity + + self.verbose = verbose + # assign some helpful member vars + self.accelerator = accelerator - self.device = accelerator.device if exists(accelerator) else device self.text_conditioned = diffusion_prior.condition_on_text_encodings + # setting the device + + if not exists(accelerator) and not exists(device): + diffusion_prior_device = next(diffusion_prior.parameters()).device + self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}') + self.device = diffusion_prior_device + else: + self.device = accelerator.device if exists(accelerator) else device + # save model self.diffusion_prior = diffusion_prior @@ -214,13 +226,9 @@ class DiffusionPriorTrainer(nn.Module): self.max_grad_norm = max_grad_norm - # verbosity - - self.verbose = verbose - # track steps internally - self.register_buffer('step', torch.tensor([0])) + self.register_buffer('step', torch.tensor([0], device = self.device)) # accelerator wrappers diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 029a258..617a906 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.10' +__version__ = '0.16.12'