Add a diffusion prior trainer that automatically handles the exponential moving average (for both training and sampling), as well as mixed precision and gradient clipping

This commit is contained in:
Phil Wang
2022-05-06 08:11:09 -07:00
parent 878b555ef7
commit 98df1ba51e
5 changed files with 154 additions and 4 deletions

View File

@@ -845,6 +845,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
loss = self.loss_fn(pred, target)
return loss
@torch.inference_mode()
@eval_decorator
def sample_batch_size(self, batch_size, text_cond):
    """Sample `batch_size` image embeddings via the full reverse diffusion chain.

    Starts from pure Gaussian noise of shape (batch_size, image_embed_dim) and
    iteratively denoises it with `p_sample`, conditioned on `text_cond`, walking
    the timesteps from num_timesteps - 1 down to 0.

    Args:
        batch_size: number of embeddings to sample in parallel.
        text_cond: conditioning information forwarded to `p_sample`
            (presumably text embeddings / encodings — confirm against caller).

    Returns:
        Tensor of shape (batch_size, image_embed_dim) holding the denoised
        image embeddings.
    """
    # Place everything on the same device as the diffusion schedule buffers.
    device = self.betas.device

    # Start the chain from isotropic Gaussian noise.
    embed = torch.randn((batch_size, self.image_embed_dim), device = device)

    # Walk the diffusion timesteps in reverse (T-1 ... 0).
    timesteps = list(range(self.num_timesteps))[::-1]
    for t in tqdm(timesteps, desc = 'sampling loop time step', total = self.num_timesteps):
        # Every sample in the batch shares the same integer timestep.
        times = torch.full((batch_size,), t, dtype = torch.long, device = device)
        embed = self.p_sample(embed, times, text_cond = text_cond)

    return embed
@torch.inference_mode()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):