mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 21:34:21 +01:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bdf85a5e9 | ||
|
|
e1fe3089df | ||
|
|
6d477d7654 | ||
|
|
531fe4b62f | ||
|
|
ec5a77fc55 | ||
|
|
fac63c61bc |
@@ -20,18 +20,20 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
|
||||
|
||||
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
|
||||
|
||||
<img src="./samples/oxford.png" width="600px" />
|
||||
<img src="./samples/oxford.png" width="450px" />
|
||||
|
||||
*ongoing at 21k steps*
|
||||
|
||||
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
|
||||
|
||||
- <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues
|
||||
|
||||
## Pre-Trained Models
|
||||
|
||||
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
|
||||
- DALL-E 2 🚧
|
||||
- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 Laion repository</a>
|
||||
|
||||
## Appreciation
|
||||
|
||||
|
||||
@@ -490,14 +490,15 @@ class NoiseScheduler(nn.Module):
|
||||
# diffusion prior
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer("beta", torch.zeros(dim))
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
||||
|
||||
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = -1, keepdim = True)
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
@@ -508,8 +509,7 @@ class ChanLayerNorm(nn.Module):
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
@@ -629,10 +629,13 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
rotary_emb = None
|
||||
rotary_emb = None,
|
||||
pb_relax_alpha = 32 ** 2
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.pb_relax_alpha = pb_relax_alpha
|
||||
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
|
||||
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
@@ -696,6 +699,9 @@ class Attention(nn.Module):
|
||||
|
||||
# attention
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
sim = sim * self.pb_relax_alpha
|
||||
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
@@ -1210,10 +1216,12 @@ class CrossAttention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
norm_context = False
|
||||
norm_context = False,
|
||||
pb_relax_alpha = 32 ** 2
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.pb_relax_alpha = pb_relax_alpha
|
||||
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
@@ -1259,6 +1267,9 @@ class CrossAttention(nn.Module):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
sim = sim * self.pb_relax_alpha
|
||||
|
||||
attn = sim.softmax(dim = -1, dtype = torch.float32)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
@@ -346,17 +346,17 @@ class TrainDecoderConfig(BaseModel):
|
||||
img_emb_url = data_config.img_embeddings_url
|
||||
text_emb_url = data_config.text_embeddings_url
|
||||
|
||||
if using_text_encodings:
|
||||
if using_text_embeddings:
|
||||
# Then we need some way to get the embeddings
|
||||
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
|
||||
|
||||
if using_clip:
|
||||
if using_text_encodings:
|
||||
if using_text_embeddings:
|
||||
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
|
||||
else:
|
||||
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
|
||||
if text_emb_url:
|
||||
assert using_text_encodings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||
|
||||
return values
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.15.3'
|
||||
__version__ = '0.16.1'
|
||||
|
||||
Reference in New Issue
Block a user