From e928ae5c346846e87257ff6db5a08cc8d7c58980 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 6 Jul 2022 12:47:48 -0700 Subject: [PATCH] default the device to the device that the diffusion prior parameters are on, if the trainer was never given the accelerator nor device --- dalle2_pytorch/trainer.py | 22 +++++++++++++++------- dalle2_pytorch/version.py | 2 +- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 3b14ab7..f288b8e 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -173,14 +173,26 @@ class DiffusionPriorTrainer(nn.Module): super().__init__() assert isinstance(diffusion_prior, DiffusionPrior) assert not exists(accelerator) or isinstance(accelerator, Accelerator) - assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device." ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) + # verbosity + + self.verbose = verbose + # assign some helpful member vars + self.accelerator = accelerator - self.device = accelerator.device if exists(accelerator) else device self.text_conditioned = diffusion_prior.condition_on_text_encodings + # setting the device + + if not exists(accelerator) and not exists(device): + diffusion_prior_device = next(diffusion_prior.parameters()).device + self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}') + self.device = diffusion_prior_device + else: + self.device = accelerator.device if exists(accelerator) else device + # save model self.diffusion_prior = diffusion_prior @@ -214,13 +226,9 @@ class DiffusionPriorTrainer(nn.Module): self.max_grad_norm = max_grad_norm - # verbosity - - self.verbose = verbose - # track steps internally - self.register_buffer('step', torch.tensor([0])) + self.register_buffer('step', torch.tensor([0], device = self.device)) # accelerator wrappers diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 029a258..617a906 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.10' +__version__ = '0.16.12'