mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 12:04:24 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79e2a3bc77 | ||
|
|
544cdd0b29 | ||
|
|
349aaca56f | ||
|
|
3ee3c56d2a |
@@ -527,25 +527,31 @@ class NoiseScheduler(nn.Module):
|
||||
# diffusion prior
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
def __init__(self, dim, eps = 1e-5, stable = False):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.stable = stable
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = x / x.amax(dim = -1, keepdim = True).detach()
|
||||
if self.stable:
|
||||
x = x / x.amax(dim = -1, keepdim = True).detach()
|
||||
|
||||
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = -1, keepdim = True)
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
def __init__(self, dim, eps = 1e-5, stable = False):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.stable = stable
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = x / x.amax(dim = 1, keepdim = True).detach()
|
||||
if self.stable:
|
||||
x = x / x.amax(dim = 1, keepdim = True).detach()
|
||||
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
@@ -669,7 +675,7 @@ class Attention(nn.Module):
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
rotary_emb = None,
|
||||
pb_relax_alpha = 32 ** 2
|
||||
pb_relax_alpha = 128
|
||||
):
|
||||
super().__init__()
|
||||
self.pb_relax_alpha = pb_relax_alpha
|
||||
@@ -760,6 +766,7 @@ class CausalTransformer(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
ff_mult = 4,
|
||||
norm_in = False,
|
||||
norm_out = True,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
@@ -768,6 +775,8 @@ class CausalTransformer(nn.Module):
|
||||
rotary_emb = True
|
||||
):
|
||||
super().__init__()
|
||||
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
|
||||
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
|
||||
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||
@@ -779,12 +788,14 @@ class CausalTransformer(nn.Module):
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
n, device = x.shape[1], x.device
|
||||
|
||||
x = self.init_norm(x)
|
||||
|
||||
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
@@ -884,7 +895,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
if remainder > 0:
|
||||
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
|
||||
mask = F.pad(mask, (0, remainder), value = 0.)
|
||||
mask = F.pad(mask, (0, remainder), value = False)
|
||||
|
||||
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
dim_head: int = 64
|
||||
heads: int = 8
|
||||
ff_mult: int = 4
|
||||
norm_in: bool = False
|
||||
norm_out: bool = True
|
||||
attn_dropout: float = 0.
|
||||
ff_dropout: float = 0.
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.23.0'
|
||||
__version__ = '0.23.3'
|
||||
|
||||
@@ -323,7 +323,7 @@ def train(
|
||||
last_snapshot = sample
|
||||
|
||||
if next_task == 'train':
|
||||
for i, (img, emb, txt) in enumerate(trainer.train_loader):
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
|
||||
# We want to count the total number of samples across all processes
|
||||
sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||
@@ -358,6 +358,7 @@ def train(
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
forward_params['text'] = tokenized_texts
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
@@ -416,7 +417,7 @@ def train(
|
||||
timer = Timer()
|
||||
accelerator.wait_for_everyone()
|
||||
i = 0
|
||||
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
|
||||
for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
|
||||
val_sample_length_tensor[0] = len(img)
|
||||
all_samples = accelerator.gather(val_sample_length_tensor)
|
||||
total_samples = all_samples.sum().item()
|
||||
|
||||
Reference in New Issue
Block a user