mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
back to no_grad for now, also keep track and restore unet devices in one_unet_in_gpu contextmanager
This commit is contained in:
@@ -278,17 +278,17 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
self.step += 1
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
def sample(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user