Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
6520d17215 fix ddim to use alpha_cumprod 2022-08-30 20:35:08 -07:00
4 changed files with 17 additions and 58 deletions

View File

@@ -634,12 +634,10 @@ Alternatively, you can also use <a href="https://github.com/mlfoundations/open_c
$ pip install open-clip-torch $ pip install open-clip-torch
``` ```
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
```python ```python
from dalle2_pytorch import OpenClipAdapter from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter('ViT-H/14') clip = OpenClipAdapter()
``` ```
Now you'll just have to worry about training the Prior and the Decoder! Now you'll just have to worry about training the Prior and the Decoder!
@@ -1068,7 +1066,7 @@ dataloader = create_image_embedding_dataloader(
) )
for img, emb in dataloader: for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256]) print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb["img"].shape) # torch.Size([32, 512]) print(emb.shape) # torch.Size([32, 512])
# Train decoder only as shown above # Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually # Or create a dataset without a loader so you can configure it manually
@@ -1128,7 +1126,6 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865 - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments - [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364 - [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] add simple outpainting, text-guided 2x size the image for starters
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
## Citations ## Citations

View File

@@ -100,9 +100,6 @@ def eval_decorator(fn):
return out return out
return inner return inner
def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
def is_list_str(x): def is_list_str(x):
if not isinstance(x, (list, tuple)): if not isinstance(x, (list, tuple)):
return False return False
@@ -317,10 +314,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
self.eos_id = 49407 # for handling 0 being also '!' self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer('ln_final') text_attention_final = self.find_layer('ln_final')
self.dim_latent_ = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook) self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1] self.clip_normalize = preprocess.transforms[-1]
self.cleared = False self.cleared = False
@@ -339,7 +333,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@property @property
def dim_latent(self): def dim_latent(self):
return self.dim_latent_ return 512
@property @property
def image_size(self): def image_size(self):
@@ -412,10 +406,7 @@ class OpenClipAdapter(BaseClipAdapter):
@property @property
def image_size(self): def image_size(self):
image_size = self.clip.visual.image_size return self.clip.visual.image_size
if isinstance(image_size, tuple):
return max(image_size)
return image_size
@property @property
def image_channels(self): def image_channels(self):
@@ -971,8 +962,6 @@ class DiffusionPriorNetwork(nn.Module):
Rearrange('b (n d) -> b n d', n = num_text_embeds) Rearrange('b (n d) -> b n d', n = num_text_embeds)
) )
self.continuous_embedded_time = not exists(num_timesteps)
self.to_time_embeds = nn.Sequential( self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds) Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -1081,7 +1070,7 @@ class DiffusionPriorNetwork(nn.Module):
null_text_embeds = self.null_text_embeds.to(text_embed.dtype) null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
text_embed = torch.where( text_embeds = torch.where(
text_keep_mask, text_keep_mask,
text_embed, text_embed,
null_text_embeds null_text_embeds
@@ -1100,9 +1089,6 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right # but let's just do it right
if self.continuous_embedded_time:
diffusion_timesteps = diffusion_timesteps.type(dtype)
time_embed = self.to_time_embeds(diffusion_timesteps) time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -1273,7 +1259,7 @@ class DiffusionPrior(nn.Module):
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.): def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1] times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist())) times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) time_pairs = list(zip(times[:-1], times[1:]))
@@ -1295,14 +1281,12 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond) pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
# derive x0
if self.predict_x_start: if self.predict_x_start:
x_start = pred x_start = pred
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
else: else:
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise) x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
pred_noise = pred
# clip x0 before maybe predicting noise
if not self.predict_x_start: if not self.predict_x_start:
x_start.clamp_(-1., 1.) x_start.clamp_(-1., 1.)
@@ -1310,17 +1294,6 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start and self.sampling_clamp_l2norm: if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start) x_start = self.l2norm_clamp_embed(x_start)
# predict noise
if self.predict_x_start:
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
else:
pred_noise = pred
if time_next < 0:
image_embed = x_start
continue
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0. noise = torch.randn_like(image_embed) if time_next > 0 else 0.
@@ -1440,7 +1413,7 @@ class DiffusionPrior(nn.Module):
**kwargs **kwargs
): ):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied' assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied' assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization' assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(image): if exists(image):
@@ -1546,8 +1519,6 @@ class SinusoidalPosEmb(nn.Module):
def forward(self, x): def forward(self, x):
dtype, device = x.dtype, x.device dtype, device = x.dtype, x.device
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
half_dim = self.dim // 2 half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb) emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
@@ -2922,25 +2893,16 @@ class Decoder(nn.Module):
pred, _ = self.parse_unet_output(learned_variance, unet_output) pred, _ = self.parse_unet_output(learned_variance, unet_output)
# predict x0
if predict_x_start: if predict_x_start:
x_start = pred x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else: else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred) x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
# maybe clip x0
if clip_denoised: if clip_denoised:
x_start = self.dynamic_threshold(x_start) x_start = self.dynamic_threshold(x_start)
# predict noise
if predict_x_start:
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
pred_noise = pred
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if not is_last_timestep else 0. noise = torch.randn_like(img) if not is_last_timestep else 0.

View File

@@ -1 +1 @@
__version__ = '1.10.8' __version__ = '1.10.2'

View File

@@ -156,7 +156,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
if text_embeddings[0] is None: if text_embeddings[0] is None:
# Generate text embeddings from text # Generate text embeddings from text
assert clip is not None, "clip is None, but text_embeddings is None" assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True).to(device=device) tokenized_texts = tokenize(txts, truncate=True)
text_embed, text_encodings = clip.embed_text(tokenized_texts) text_embed, text_encodings = clip.embed_text(tokenized_texts)
sample_params["text_encodings"] = text_encodings sample_params["text_encodings"] = text_encodings
else: else:
@@ -229,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
metrics["KID_std"] = kid_std.item() metrics["KID_std"] = kid_std.item()
if exists(LPIPS): if exists(LPIPS):
# Convert from [0, 1] to [-1, 1] # Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1) renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1) renorm_generated_images = generated_images.mul(2).sub(1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync) lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device) lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images) lpips.update(renorm_real_images, renorm_generated_images)
@@ -480,7 +480,7 @@ def train(
else: else:
# Then we need to pass the text instead # Then we need to pass the text instead
assert clip is not None assert clip is not None
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device) tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})" assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
text_embed, text_encodings = clip.embed_text(tokenized_texts) text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings forward_params['text_encodings'] = text_encodings