mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
add diffusion prior trainer, which automatically takes care of the exponential moving average (for training and sampling), as well as mixed precision and gradient clipping
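For context, a wrapper of this kind typically keeps an EMA copy of the prior for sampling, runs the forward pass under autocast with a gradient scaler, and clips gradients before each optimizer step. The sketch below illustrates that general pattern only; it is not the repo's actual DiffusionPriorTrainer API, and the class name, ema_decay / max_grad_norm arguments, and update() method are hypothetical. It assumes the prior's forward pass returns the diffusion loss directly, consistent with the hunk below.

    # Minimal sketch of an EMA + mixed-precision + gradient-clipping training wrapper.
    # NOT the repo's DiffusionPriorTrainer: names (ema_decay, max_grad_norm, update)
    # are hypothetical illustrations of the pattern described in the commit message.
    import copy
    import torch

    class PriorTrainerSketch:
        def __init__(self, prior, lr = 3e-4, ema_decay = 0.999, max_grad_norm = 0.5):
            self.prior = prior
            self.ema_prior = copy.deepcopy(prior)        # EMA copy, used for sampling
            self.ema_decay = ema_decay
            self.max_grad_norm = max_grad_norm
            self.optimizer = torch.optim.Adam(prior.parameters(), lr = lr)
            self.scaler = torch.cuda.amp.GradScaler()    # mixed-precision loss scaling

        def update(self, *args, **kwargs):
            with torch.cuda.amp.autocast():
                loss = self.prior(*args, **kwargs)       # forward pass returns the diffusion loss

            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)         # unscale before clipping
            torch.nn.utils.clip_grad_norm_(self.prior.parameters(), self.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # exponential moving average of the online weights
            with torch.no_grad():
                for ema_p, p in zip(self.ema_prior.parameters(), self.prior.parameters()):
                    ema_p.lerp_(p, 1. - self.ema_decay)
            return loss.item()

        @torch.inference_mode()
        def sample(self, *args, **kwargs):
            return self.ema_prior.sample(*args, **kwargs)  # sample from the EMA weights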
@@ -845,6 +845,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
        loss = self.loss_fn(pred, target)
        return loss

    @torch.inference_mode()
    @eval_decorator
    def sample_batch_size(self, batch_size, text_cond):
        device = self.betas.device
        shape = (batch_size, self.image_embed_dim)

        img = torch.randn(shape, device = device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
        return img

    @torch.inference_mode()
    @eval_decorator
    def sample(self, text, num_samples_per_batch = 2):
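The hunk shows the standard ancestral sampling loop: start from Gaussian noise of shape (batch_size, image_embed_dim) and denoise with p_sample from timestep T-1 down to 0. The truncated sample signature with num_samples_per_batch = 2 matches the DALL-E 2 paper's trick of drawing several candidate image embeddings per caption and keeping the one most similar to the text embedding; the selection step sketched below is illustrative only (the tensor names, shapes, and the cosine-similarity reranking are assumptions, not the repo's exact code).

    # Sketch of "sample n candidates, keep the best" reranking for a diffusion prior.
    # Names and shapes (candidates, text_embed) are illustrative, not taken from the repo.
    import torch
    import torch.nn.functional as F

    def pick_best_image_embed(candidates, text_embed):
        # candidates: (batch, n, dim) image embeddings sampled from the prior
        # text_embed: (batch, dim) text embedding for each caption
        sims = F.cosine_similarity(candidates, text_embed.unsqueeze(1), dim = -1)  # (batch, n)
        best = sims.argmax(dim = -1)                                               # (batch,)
        return candidates[torch.arange(candidates.shape[0]), best]                 # (batch, dim)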