Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 11:34:29 +01:00
Compare commits

5 commits:

- 972ee973bc
- 79e2a3bc77
- 544cdd0b29
- 349aaca56f
- 3ee3c56d2a
@@ -527,25 +527,31 @@ class NoiseScheduler(nn.Module):
 # diffusion prior
 
 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        x = x / x.amax(dim = -1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = -1, keepdim = True).detach()
+
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
 
 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
-        x = x / x.amax(dim = 1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = 1, keepdim = True).detach()
+
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
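Note on the change above: layer norm is scale-invariant in exact arithmetic, so dividing the input by its (detached) row-wise maximum only changes the numerical behaviour, not the result. The snippet below illustrates that with a plain functional layer norm; it is a standalone sketch, not code from the repository.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 512, dtype = torch.float64) * 1e3

# layer norm of x vs. layer norm of the max-rescaled x
# (with 512 entries per row the max is positive with overwhelming probability,
#  so the rescale is by a positive constant and the output is unchanged)
y_plain  = F.layer_norm(x, (512,), eps = 0.)
y_stable = F.layer_norm(x / x.amax(dim = -1, keepdim = True), (512,), eps = 0.)

print(torch.allclose(y_plain, y_stable))  # True, up to float64 rounding
```

The `stable = True` path therefore only matters for numerical range (for example, keeping the variance computation well-behaved in half precision), which is presumably why it is now opt-in rather than always on.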
@@ -669,7 +675,7 @@ class Attention(nn.Module):
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 32 ** 2
+        pb_relax_alpha = 128
     ):
         super().__init__()
         self.pb_relax_alpha = pb_relax_alpha
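For context, `pb_relax_alpha` controls the PB-relax trick from CogView for keeping attention logits in a safe range. The sketch below shows the general idea (compute logits with the queries scaled down by alpha, subtract the per-query max, scale back up before the softmax); it is an assumption about how the knob is used, not the repository's exact `Attention.forward`.

```python
import torch
from torch import einsum

def pb_relax_logits(q, k, scale, alpha = 128):
    # logits computed at 1/alpha of their true magnitude, shifted by the per-query max,
    # then rescaled; this equals the true logits minus a per-row constant, so the softmax
    # is unchanged while the intermediate values stay small
    sim = einsum('b h i d, b h j d -> b h i j', q * (scale / alpha), k)
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    return sim * alpha

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
attn = pb_relax_logits(q, k, scale = 64 ** -0.5).softmax(dim = -1)
```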
@@ -760,6 +766,7 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        norm_in = False,
         norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
@@ -768,6 +775,8 @@ class CausalTransformer(nn.Module):
         rotary_emb = True
     ):
         super().__init__()
+        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
+
         self.rel_pos_bias = RelPosBias(heads = heads)
 
         rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
@@ -779,20 +788,18 @@ class CausalTransformer(nn.Module):
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
-        self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
 
-    def forward(
-        self,
-        x,
-        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
-    ):
+    def forward(self, x):
         n, device = x.shape[1], x.device
 
+        x = self.init_norm(x)
+
         attn_bias = self.rel_pos_bias(n, n + 1, device = device)
 
         for attn, ff in self.layers:
-            x = attn(x, mask = mask, attn_bias = attn_bias) + x
+            x = attn(x, attn_bias = attn_bias) + x
             x = ff(x) + x
 
         out = self.norm(x)
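Dropping the `mask` argument from `CausalTransformer.forward` goes hand in hand with the `DiffusionPriorNetwork` changes further down: padding positions in the text encodings are now replaced with learned per-position embeddings, so every token can be attended to and the transformer no longer needs a mask. A hedged usage sketch of the new constructor flags follows; `norm_in`, `norm_out` and `dim` come from the diff, while the import path, `depth` and the concrete sizes are assumptions.

```python
import torch
from dalle2_pytorch.dalle2_pytorch import CausalTransformer  # import path is an assumption

transformer = CausalTransformer(
    dim = 512,
    depth = 6,
    norm_in = True,   # LayerNorm on the incoming tokens, per the BLOOM / YaLM note in the diff
    norm_out = True   # final LayerNorm, now built with stable = True internally
)

out = transformer(torch.randn(2, 77, 512))  # no mask argument any more
```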
@@ -806,7 +813,7 @@ class DiffusionPriorNetwork(nn.Module):
         num_time_embeds = 1,
         num_image_embeds = 1,
         num_text_embeds = 1,
-        attend_all_text_encodings = True,
+        max_text_len = 256,
         **kwargs
     ):
         super().__init__()
@@ -832,7 +839,10 @@ class DiffusionPriorNetwork(nn.Module):
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
-        self.attend_all_text_encodings = attend_all_text_encodings
+        # dalle1 learned padding strategy
+
+        self.max_text_len = max_text_len
+        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
     def forward_with_cond_scale(
         self,
@@ -872,11 +882,28 @@ class DiffusionPriorNetwork(nn.Module):
 
         if not exists(text_encodings):
            text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
 
-        if self.attend_all_text_encodings:
-            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
-        else:
-            mask = torch.any(text_encodings != 0., dim = -1)
+        mask = torch.any(text_encodings != 0., dim = -1)
+
+        # replace any padding in the text encodings with learned padding tokens unique across position
+
+        text_encodings = text_encodings[:, :self.max_text_len]
+        mask = mask[:, :self.max_text_len]
+
+        text_len = text_encodings.shape[-2]
+        remainder = self.max_text_len - text_len
+
+        if remainder > 0:
+            text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
+            mask = F.pad(mask, (0, remainder), value = False)
+
+        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+
+        text_encodings = torch.where(
+            rearrange(mask, 'b n -> b n 1'),
+            text_encodings,
+            null_text_embeds
+        )
 
         # classifier free guidance
 
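The padding replacement can be seen in isolation below: wherever the mask marks a position as padding, the zero encoding is swapped for a position-specific learned embedding (the dalle1 strategy mentioned in the removed comment). This is a toy illustration with made-up sizes, not the repository's forward pass.

```python
import torch
from einops import rearrange

batch, max_text_len, dim = 2, 6, 4

text_encodings = torch.zeros(batch, max_text_len, dim)
text_encodings[:, :3] = torch.randn(batch, 3, dim)    # only the first 3 positions hold real text
mask = torch.any(text_encodings != 0., dim = -1)      # True where a real encoding lives

null_text_embed = torch.randn(1, max_text_len, dim)   # learned in the real model, one per position

text_encodings = torch.where(
    rearrange(mask, 'b n -> b n 1'),
    text_encodings,
    null_text_embed
)

# padding slots now hold the per-position learned embeddings instead of zeros
print(torch.equal(text_encodings[0, 3], null_text_embed[0, 3]))  # True
```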
@@ -910,7 +937,7 @@ class DiffusionPriorNetwork(nn.Module):
 
         # attend
 
-        tokens = self.causal_transformer(tokens, mask = mask)
+        tokens = self.causal_transformer(tokens)
 
         # get learned query, which should predict the image embedding (per DDPM timestep)
 
@@ -2301,6 +2328,9 @@ class Decoder(nn.Module):
         img = torch.randn(shape, device = device)
 
+        if not is_latent_diffusion:
+            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
+
         for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
             alpha = alphas[time]
             alpha_next = alphas[time_next]
 
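The `lowres_cond_img` line added to this sampling loop relies on the library's `maybe` helper, which applies a function only when its argument exists and otherwise passes `None` straight through. A minimal sketch of that pattern follows; the helper in the repository may differ in detail.

```python
from functools import wraps

def exists(val):
    return val is not None

def maybe(fn):
    # apply fn only when the argument exists; None passes through untouched
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

normalize_img = lambda img: img * 2 - 1   # map [0, 1] images to [-1, 1]

print(maybe(normalize_img)(None))         # None, since there is no low-res conditioning image
```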
@@ -129,14 +129,15 @@ class AdapterConfig(BaseModel):
 class DiffusionPriorNetworkConfig(BaseModel):
     dim: int
     depth: int
+    max_text_len: int = None
     num_timesteps: int = None
     num_time_embeds: int = 1
     num_image_embeds: int = 1
     num_text_embeds: int = 1
-    attend_all_text_encodings: bool = True
     dim_head: int = 64
     heads: int = 8
     ff_mult: int = 4
+    norm_in: bool = False
     norm_out: bool = True
     attn_dropout: float = 0.
     ff_dropout: float = 0.
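A hedged example of how the new fields could appear when building the prior-network config; the field names come from the diff, while the import path and the concrete values are illustrative.

```python
# illustrative only; import path is an assumption
from dalle2_pytorch.train_configs import DiffusionPriorNetworkConfig

prior_network_config = DiffusionPriorNetworkConfig(
    dim = 768,
    depth = 12,
    max_text_len = 77,    # new field; leaving it at None presumably defers to the network default
    norm_in = False,
    norm_out = True,
    attn_dropout = 0.05,
    ff_dropout = 0.05
)
```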
@@ -1 +1 @@
-__version__ = '0.22.3'
+__version__ = '0.23.4'
@@ -323,7 +323,7 @@ def train(
             last_snapshot = sample
 
         if next_task == 'train':
-            for i, (img, emb, txt) in enumerate(trainer.train_loader):
+            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
                # We want to count the total number of samples across all processes
                sample_length_tensor[0] = len(img)
                all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
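The gather-based sample counting that the changed loop feeds into can be sketched standalone. The pattern mirrors the surrounding lines of the script and uses only the `accelerate` call that appears there (`accelerator.gather`); the setup around it is illustrative.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device

# count how many samples this step saw across all processes
sample_length_tensor = torch.zeros(1, dtype = torch.int, device = device)
sample_length_tensor[0] = 32                               # samples seen on this process
all_samples = accelerator.gather(sample_length_tensor)     # one entry per process after gather
total_samples = all_samples.sum().item()

print(total_samples)  # 32 when run on a single process
```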
@@ -358,6 +358,7 @@ def train(
                 else:
                     # Then we need to pass the text instead
                     tokenized_texts = tokenize(txt, truncate=True)
+                    assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                     forward_params['text'] = tokenized_texts
                 loss = trainer.forward(img, **forward_params, unet_number=unet)
                 trainer.update(unet_number=unet)
@@ -416,7 +417,7 @@ def train(
         timer = Timer()
         accelerator.wait_for_everyone()
         i = 0
-        for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
+        for i, (img, emb, txt) in enumerate(dataloaders['val']):  # Use the accelerate prepared loader
             val_sample_length_tensor[0] = len(img)
             all_samples = accelerator.gather(val_sample_length_tensor)
             total_samples = all_samples.sum().item()