when in doubt, make it a hyperparameter

This commit is contained in:
Phil Wang
2022-05-07 07:52:02 -07:00
parent cd5f2c1de4
commit 8f93729d19
2 changed files with 11 additions and 4 deletions

View File

@@ -807,6 +807,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False,
training_clamp_l2norm = False,
init_image_embed_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
@@ -845,6 +846,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
@@ -879,11 +881,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
image_embed = torch.randn(shape, device=device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
times = torch.full((b,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
return image_embed
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.1.6',
version = '0.1.7',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',