Compare commits

...

10 Commits

5 changed files with 73 additions and 59 deletions

View File

@@ -45,6 +45,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/malumadev">MalumaDev</a> for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -420,7 +421,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
## Training on Preprocessed CLIP Embeddings
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
Working example below

View File

@@ -165,7 +165,7 @@ def unnormalize_zero_to_one(normed_img):
# clip related adapters
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
@@ -226,7 +226,7 @@ class XClipAdapter(BaseClipAdapter):
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
return EmbeddedText(l2norm(text_embed), text_encodings)
@torch.no_grad()
def embed_image(self, image):
@@ -262,7 +262,7 @@ class CoCaAdapter(BaseClipAdapter):
text_mask = text != 0
text_embed, text_encodings = self.clip.embed_text(text)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(text_embed, text_encodings, text_mask)
return EmbeddedText(text_embed, text_encodings)
@torch.no_grad()
def embed_image(self, image):
@@ -323,7 +323,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
text_encodings = self.text_encodings
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
del self.text_encodings
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
@torch.no_grad()
def embed_image(self, image):
@@ -760,6 +760,7 @@ class CausalTransformer(nn.Module):
dim_head = 64,
heads = 8,
ff_mult = 4,
norm_in = False,
norm_out = True,
attn_dropout = 0.,
ff_dropout = 0.,
@@ -768,6 +769,8 @@ class CausalTransformer(nn.Module):
rotary_emb = True
):
super().__init__()
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -782,17 +785,15 @@ class CausalTransformer(nn.Module):
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
def forward(
self,
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
def forward(self, x):
n, device = x.shape[1], x.device
x = self.init_norm(x)
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
for attn, ff in self.layers:
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = attn(x, attn_bias = attn_bias) + x
x = ff(x) + x
out = self.norm(x)
@@ -806,6 +807,7 @@ class DiffusionPriorNetwork(nn.Module):
num_time_embeds = 1,
num_image_embeds = 1,
num_text_embeds = 1,
max_text_len = 256,
**kwargs
):
super().__init__()
@@ -831,6 +833,11 @@ class DiffusionPriorNetwork(nn.Module):
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
# dalle1 learned padding strategy
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
def forward_with_cond_scale(
self,
*args,
@@ -852,7 +859,6 @@ class DiffusionPriorNetwork(nn.Module):
*,
text_embed,
text_encodings = None,
mask = None,
cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -870,9 +876,28 @@ class DiffusionPriorNetwork(nn.Module):
if not exists(text_encodings):
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
mask = torch.any(text_encodings != 0., dim = -1)
if not exists(mask):
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
# replace any padding in the text encodings with learned padding tokens unique across position
text_encodings = text_encodings[:, :self.max_text_len]
mask = mask[:, :self.max_text_len]
text_len = text_encodings.shape[-2]
remainder = self.max_text_len - text_len
if remainder > 0:
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
mask = F.pad(mask, (0, remainder), value = False)
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
text_encodings = torch.where(
rearrange(mask, 'b n -> b n 1'),
text_encodings,
null_text_embeds
)
# classifier free guidance
@@ -889,9 +914,8 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.to_time_embeds(diffusion_timesteps)
@@ -907,7 +931,7 @@ class DiffusionPriorNetwork(nn.Module):
# attend
tokens = self.causal_transformer(tokens, mask = mask)
tokens = self.causal_transformer(tokens)
# get learned query, which should predict the image embedding (per DDPM timestep)
@@ -1152,12 +1176,12 @@ class DiffusionPrior(nn.Module):
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_embed, text_encodings = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
text_cond = {**text_cond, 'text_encodings': text_encodings}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
@@ -1185,7 +1209,6 @@ class DiffusionPrior(nn.Module):
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
image_embed = None,
text_encodings = None, # as well as CLIP text encodings
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
*args,
**kwargs
):
@@ -1199,14 +1222,13 @@ class DiffusionPrior(nn.Module):
# calculate text conditionings, based on what is passed in
if exists(text):
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_embed, text_encodings = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
text_cond = {**text_cond, 'text_encodings': text_encodings}
# timestep conditioning from ddpm
@@ -1745,7 +1767,6 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
text_mask = None,
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None,
@@ -1817,23 +1838,27 @@ class Unet(nn.Module):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
text_mask = torch.any(text_encodings != 0., dim = -1)
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
text_mask = text_mask[:, :self.max_text_len]
text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
text_mask = F.pad(text_mask, (0, remainder), value = False)
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
@@ -2215,10 +2240,10 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
@@ -2250,16 +2275,16 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
@@ -2275,7 +2300,6 @@ class Decoder(nn.Module):
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
@@ -2288,7 +2312,7 @@ class Decoder(nn.Module):
return unnormalize_img
@torch.no_grad()
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
@@ -2304,7 +2328,7 @@ class Decoder(nn.Module):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
@@ -2343,7 +2367,7 @@ class Decoder(nn.Module):
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -2361,7 +2385,6 @@ class Decoder(nn.Module):
times,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
lowres_cond_img = lowres_cond_img,
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
@@ -2421,7 +2444,6 @@ class Decoder(nn.Module):
self,
image_embed = None,
text = None,
text_mask = None,
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
@@ -2435,14 +2457,11 @@ class Decoder(nn.Module):
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip)
_, text_encodings, text_mask = self.clip.embed_text(text)
_, text_encodings = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
if self.condition_on_text_encodings:
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
img = None
is_cuda = next(self.parameters()).is_cuda
@@ -2468,7 +2487,6 @@ class Decoder(nn.Module):
shape,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
learned_variance = learned_variance,
@@ -2492,7 +2510,6 @@ class Decoder(nn.Module):
text = None,
image_embed = None,
text_encodings = None,
text_mask = None,
unet_number = None,
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
):
@@ -2521,14 +2538,11 @@ class Decoder(nn.Module):
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings, text_mask = self.clip.embed_text(text)
_, text_encodings = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
if self.condition_on_text_encodings:
text_mask = default(text_mask, lambda: torch.any(text_encodings != 0., dim = -1))
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
@@ -2547,7 +2561,7 @@ class Decoder(nn.Module):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
if not return_lowres_cond_image:
return losses

View File

@@ -129,6 +129,7 @@ class AdapterConfig(BaseModel):
class DiffusionPriorNetworkConfig(BaseModel):
dim: int
depth: int
max_text_len: int = None
num_timesteps: int = None
num_time_embeds: int = 1
num_image_embeds: int = 1
@@ -136,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
dim_head: int = 64
heads: int = 8
ff_mult: int = 4
norm_in: bool = False
norm_out: bool = True
attn_dropout: float = 0.
ff_dropout: float = 0.

View File

@@ -1 +1 @@
__version__ = '0.21.0'
__version__ = '0.23.2'

View File

@@ -126,9 +126,9 @@ def report_cosine_sims(
# we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned:
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
text_embedding, text_encodings = trainer.embed_text(text_data)
text_cond = dict(
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
text_embed=text_embedding, text_encodings=text_encodings
)
else:
text_embedding = text_data
@@ -146,15 +146,12 @@ def report_cosine_sims(
if text_conditioned:
text_encodings_shuffled = text_encodings[rolled_idx]
text_mask_shuffled = text_mask[rolled_idx]
else:
text_encodings_shuffled = None
text_mask_shuffled = None
text_cond_shuffled = dict(
text_embed=text_embed_shuffled,
text_encodings=text_encodings_shuffled,
mask=text_mask_shuffled,
text_encodings=text_encodings_shuffled
)
# prepare the text embedding