back to no_grad for now, also keep track and restore unet devices in one_unet_in_gpu contextmanager

This commit is contained in:
Phil Wang
2022-05-16 09:36:14 -07:00
parent bb151ca6b1
commit dab106d4e5
3 changed files with 18 additions and 14 deletions

View File

@@ -278,17 +278,17 @@ class DiffusionPriorTrainer(nn.Module):
self.step += 1
@torch.inference_mode()
@torch.no_grad()
@cast_torch_tensor
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
@torch.no_grad()
@cast_torch_tensor
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.inference_mode()
@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)