mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 21:04:28 +01:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79e2a3bc77 | ||
|
|
544cdd0b29 | ||
|
|
349aaca56f | ||
|
|
3ee3c56d2a | ||
|
|
cd26c6b17d | ||
|
|
775abc4df6 | ||
|
|
11b1d533a0 |
@@ -527,25 +527,31 @@ class NoiseScheduler(nn.Module):
|
|||||||
# diffusion prior
|
# diffusion prior
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
def __init__(self, dim, eps = 1e-5):
|
def __init__(self, dim, eps = 1e-5, stable = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
self.stable = stable
|
||||||
self.g = nn.Parameter(torch.ones(dim))
|
self.g = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x / x.amax(dim = -1, keepdim = True).detach()
|
if self.stable:
|
||||||
|
x = x / x.amax(dim = -1, keepdim = True).detach()
|
||||||
|
|
||||||
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
||||||
mean = torch.mean(x, dim = -1, keepdim = True)
|
mean = torch.mean(x, dim = -1, keepdim = True)
|
||||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||||
|
|
||||||
class ChanLayerNorm(nn.Module):
|
class ChanLayerNorm(nn.Module):
|
||||||
def __init__(self, dim, eps = 1e-5):
|
def __init__(self, dim, eps = 1e-5, stable = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
self.stable = stable
|
||||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x / x.amax(dim = 1, keepdim = True).detach()
|
if self.stable:
|
||||||
|
x = x / x.amax(dim = 1, keepdim = True).detach()
|
||||||
|
|
||||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||||
@@ -669,7 +675,7 @@ class Attention(nn.Module):
|
|||||||
dropout = 0.,
|
dropout = 0.,
|
||||||
causal = False,
|
causal = False,
|
||||||
rotary_emb = None,
|
rotary_emb = None,
|
||||||
pb_relax_alpha = 32 ** 2
|
pb_relax_alpha = 128
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pb_relax_alpha = pb_relax_alpha
|
self.pb_relax_alpha = pb_relax_alpha
|
||||||
@@ -760,6 +766,7 @@ class CausalTransformer(nn.Module):
|
|||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
ff_mult = 4,
|
ff_mult = 4,
|
||||||
|
norm_in = False,
|
||||||
norm_out = True,
|
norm_out = True,
|
||||||
attn_dropout = 0.,
|
attn_dropout = 0.,
|
||||||
ff_dropout = 0.,
|
ff_dropout = 0.,
|
||||||
@@ -768,6 +775,8 @@ class CausalTransformer(nn.Module):
|
|||||||
rotary_emb = True
|
rotary_emb = True
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
|
||||||
|
|
||||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||||
|
|
||||||
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||||
@@ -779,20 +788,18 @@ class CausalTransformer(nn.Module):
|
|||||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||||
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
||||||
|
|
||||||
def forward(
|
def forward(self, x):
|
||||||
self,
|
|
||||||
x,
|
|
||||||
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
|
|
||||||
):
|
|
||||||
n, device = x.shape[1], x.device
|
n, device = x.shape[1], x.device
|
||||||
|
|
||||||
|
x = self.init_norm(x)
|
||||||
|
|
||||||
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||||
|
|
||||||
for attn, ff in self.layers:
|
for attn, ff in self.layers:
|
||||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
x = attn(x, attn_bias = attn_bias) + x
|
||||||
x = ff(x) + x
|
x = ff(x) + x
|
||||||
|
|
||||||
out = self.norm(x)
|
out = self.norm(x)
|
||||||
@@ -806,6 +813,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
num_time_embeds = 1,
|
num_time_embeds = 1,
|
||||||
num_image_embeds = 1,
|
num_image_embeds = 1,
|
||||||
num_text_embeds = 1,
|
num_text_embeds = 1,
|
||||||
|
max_text_len = 256,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -831,6 +839,11 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||||
|
|
||||||
|
# dalle1 learned padding strategy
|
||||||
|
|
||||||
|
self.max_text_len = max_text_len
|
||||||
|
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
|
||||||
|
|
||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
@@ -852,7 +865,6 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
*,
|
*,
|
||||||
text_embed,
|
text_embed,
|
||||||
text_encodings = None,
|
text_encodings = None,
|
||||||
mask = None,
|
|
||||||
cond_drop_prob = 0.
|
cond_drop_prob = 0.
|
||||||
):
|
):
|
||||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||||
@@ -870,9 +882,29 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
if not exists(text_encodings):
|
if not exists(text_encodings):
|
||||||
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
||||||
|
|
||||||
mask = torch.any(text_encodings != 0., dim = -1)
|
mask = torch.any(text_encodings != 0., dim = -1)
|
||||||
|
|
||||||
|
# replace any padding in the text encodings with learned padding tokens unique across position
|
||||||
|
|
||||||
|
text_encodings = text_encodings[:, :self.max_text_len]
|
||||||
|
mask = mask[:, :self.max_text_len]
|
||||||
|
|
||||||
|
text_len = text_encodings.shape[-2]
|
||||||
|
remainder = self.max_text_len - text_len
|
||||||
|
|
||||||
|
if remainder > 0:
|
||||||
|
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
|
||||||
|
mask = F.pad(mask, (0, remainder), value = False)
|
||||||
|
|
||||||
|
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
|
||||||
|
|
||||||
|
text_encodings = torch.where(
|
||||||
|
rearrange(mask, 'b n -> b n 1'),
|
||||||
|
text_encodings,
|
||||||
|
null_text_embeds
|
||||||
|
)
|
||||||
|
|
||||||
# classifier free guidance
|
# classifier free guidance
|
||||||
|
|
||||||
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
|
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
|
||||||
@@ -905,7 +937,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
|
|
||||||
# attend
|
# attend
|
||||||
|
|
||||||
tokens = self.causal_transformer(tokens, mask = mask)
|
tokens = self.causal_transformer(tokens)
|
||||||
|
|
||||||
# get learned query, which should predict the image embedding (per DDPM timestep)
|
# get learned query, which should predict the image embedding (per DDPM timestep)
|
||||||
|
|
||||||
@@ -1812,6 +1844,7 @@ class Unet(nn.Module):
|
|||||||
text_tokens = None
|
text_tokens = None
|
||||||
|
|
||||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||||
|
assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
|
||||||
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
||||||
|
|
||||||
text_mask = torch.any(text_encodings != 0., dim = -1)
|
text_mask = torch.any(text_encodings != 0., dim = -1)
|
||||||
|
|||||||
@@ -129,6 +129,7 @@ class AdapterConfig(BaseModel):
|
|||||||
class DiffusionPriorNetworkConfig(BaseModel):
|
class DiffusionPriorNetworkConfig(BaseModel):
|
||||||
dim: int
|
dim: int
|
||||||
depth: int
|
depth: int
|
||||||
|
max_text_len: int = None
|
||||||
num_timesteps: int = None
|
num_timesteps: int = None
|
||||||
num_time_embeds: int = 1
|
num_time_embeds: int = 1
|
||||||
num_image_embeds: int = 1
|
num_image_embeds: int = 1
|
||||||
@@ -136,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
|||||||
dim_head: int = 64
|
dim_head: int = 64
|
||||||
heads: int = 8
|
heads: int = 8
|
||||||
ff_mult: int = 4
|
ff_mult: int = 4
|
||||||
|
norm_in: bool = False
|
||||||
norm_out: bool = True
|
norm_out: bool = True
|
||||||
attn_dropout: float = 0.
|
attn_dropout: float = 0.
|
||||||
ff_dropout: float = 0.
|
ff_dropout: float = 0.
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.22.0'
|
__version__ = '0.23.3'
|
||||||
|
|||||||
@@ -323,7 +323,7 @@ def train(
|
|||||||
last_snapshot = sample
|
last_snapshot = sample
|
||||||
|
|
||||||
if next_task == 'train':
|
if next_task == 'train':
|
||||||
for i, (img, emb, txt) in enumerate(trainer.train_loader):
|
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||||
# We want to count the total number of samples across all processes
|
# We want to count the total number of samples across all processes
|
||||||
sample_length_tensor[0] = len(img)
|
sample_length_tensor[0] = len(img)
|
||||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||||
@@ -358,6 +358,7 @@ def train(
|
|||||||
else:
|
else:
|
||||||
# Then we need to pass the text instead
|
# Then we need to pass the text instead
|
||||||
tokenized_texts = tokenize(txt, truncate=True)
|
tokenized_texts = tokenize(txt, truncate=True)
|
||||||
|
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||||
forward_params['text'] = tokenized_texts
|
forward_params['text'] = tokenized_texts
|
||||||
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||||
trainer.update(unet_number=unet)
|
trainer.update(unet_number=unet)
|
||||||
@@ -416,7 +417,7 @@ def train(
|
|||||||
timer = Timer()
|
timer = Timer()
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
i = 0
|
i = 0
|
||||||
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
|
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
|
||||||
val_sample_length_tensor[0] = len(img)
|
val_sample_length_tensor[0] = len(img)
|
||||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||||
total_samples = all_samples.sum().item()
|
total_samples = all_samples.sum().item()
|
||||||
|
|||||||
Reference in New Issue
Block a user