Compare commits

...

7 Commits

Author · SHA1 · Message · Date
Phil Wang · 3bdf85a5e9 · bring in two tricks from the cogview paper for reducing the chances of overflow, for attention and layernorm · 2022-07-05 14:21:21 -07:00
Phil Wang · e1fe3089df · do bias-less layernorm manually · 2022-07-05 13:09:58 -07:00
Phil Wang · 6d477d7654 · link to dalle2 laion · 2022-07-05 11:43:07 -07:00
Phil Wang · 531fe4b62f · status · 2022-07-05 10:46:55 -07:00
Phil Wang · ec5a77fc55 · 0.15.4 · 2022-07-02 08:56:34 -07:00
Aidan Dempster · fac63c61bc · Fixed variable naming issue (#183) · 2022-07-02 08:56:03 -07:00
Phil Wang · 3d23ba4aa5 · add ability to specify full self attention on specific stages in the unet · 2022-07-01 10:22:07 -07:00
4 changed files with 71 additions and 34 deletions
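The headline change is the pair of CogView-style overflow mitigations referenced in commit 3bdf85a5e9. Below is a minimal, self-contained sketch of the attention half of the trick, written against plain PyTorch rather than the repository's Attention class (the function name and tensor shapes are illustrative): the logits are computed with the usual 1/sqrt(d) scale divided by a large constant alpha, the per-row maximum is subtracted, and the result is multiplied back by alpha before the softmax, so intermediate scores stay small enough to avoid overflowing in fp16.

import torch

def pb_relax_attention(q, k, v, dim_head = 64, alpha = 32 ** 2):
    # fold 1 / alpha into the usual 1 / sqrt(dim_head) scaling of the logits
    scale = dim_head ** -0.5 / alpha
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    # shift so the largest logit per row is 0, then undo the 1 / alpha scaling
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    sim = sim * alpha
    # softmax in float32 for extra headroom, as in the diffs below
    attn = sim.softmax(dim = -1, dtype = torch.float32)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

# example: q = k = v = torch.randn(1, 8, 128, 64) -> output of the same shape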

View File

@@ -20,18 +20,20 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
<img src="./samples/oxford.png" width="600px" />
<img src="./samples/oxford.png" width="450px" />
*ongoing at 21k steps*
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
- <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
- DALL-E 2 🚧
- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 Laion repository</a>
## Appreciation

View File

@@ -63,11 +63,16 @@ def default(val, d):
return val
return d() if callable(d) else d
def cast_tuple(val, length = 1):
def cast_tuple(val, length = None):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else ((val,) * length)
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
if exists(length):
assert len(out) == length
return out
def module_device(module):
return next(module.parameters()).device
@@ -485,14 +490,15 @@ class NoiseScheduler(nn.Module):
# diffusion prior
class LayerNorm(nn.Module):
def __init__(self, dim):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
@@ -503,8 +509,7 @@ class ChanLayerNorm(nn.Module):
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
return (x - mean) * (var + self.eps).rsqrt() * self.g
class Residual(nn.Module):
def __init__(self, fn):
@@ -624,10 +629,13 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
rotary_emb = None
rotary_emb = None,
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.scale = dim_head ** -0.5
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.heads = heads
inner_dim = dim_head * heads
@@ -691,6 +699,9 @@ class Attention(nn.Module):
# attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)
@@ -1205,10 +1216,12 @@ class CrossAttention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False
norm_context = False,
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.scale = dim_head ** -0.5
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.heads = heads
inner_dim = dim_head * heads
@@ -1254,6 +1267,9 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -1341,6 +1357,7 @@ class Unet(nn.Module):
dim_mults=(1, 2, 4, 8),
channels = 3,
channels_out = None,
self_attn = False,
attn_dim_head = 32,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1387,6 +1404,8 @@ class Unet(nn.Module):
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
num_stages = len(in_out)
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
@@ -1450,14 +1469,16 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
self_attn = cast_tuple(self_attn, num_stages)
create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
# resnet block klass
resnet_groups = cast_tuple(resnet_groups, len(in_out))
resnet_groups = cast_tuple(resnet_groups, num_stages)
top_level_resnet_group = first(resnet_groups)
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
assert len(resnet_groups) == len(in_out)
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)
# downsample klass
@@ -1479,9 +1500,9 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
skip_connect_dims = [] # keeping track of skip connection dimensions
skip_connect_dims = [] # keeping track of skip connection dimensions
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
@@ -1489,30 +1510,42 @@ class Unet(nn.Module):
dim_layer = dim_out if memory_efficient else dim_in
skip_connect_dims.append(dim_layer)
attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_layer)
elif sparse_attn:
attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
]))
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None
skip_connect_dim = skip_connect_dims.pop()
attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_out)
elif sparse_attn:
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))
@@ -1690,18 +1723,19 @@ class Unet(nn.Module):
hiddens = []
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)
x = init_block(x, t, c)
x = sparse_attn(x)
hiddens.append(x)
for resnet_block in resnet_blocks:
x = resnet_block(x, t, c)
hiddens.append(x)
x = attn(x)
hiddens.append(x)
if exists(post_downsample):
x = post_downsample(x)
@@ -1714,15 +1748,15 @@ class Unet(nn.Module):
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
for init_block, resnet_blocks, attn, upsample in self.ups:
x = connect_skip(x)
x = init_block(x, t, c)
x = sparse_attn(x)
for resnet_block in resnet_blocks:
x = connect_skip(x)
x = resnet_block(x, t, c)
x = attn(x)
x = upsample(x)
x = torch.cat((x, r), dim = 1)
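For reference, the bias-less LayerNorm from commit e1fe3089df can be exercised on its own roughly as follows; this mirrors the hunk above, with only the class name changed for illustration (in the repository it simply replaces LayerNorm).

import torch
from torch import nn

class BiaslessLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))   # learned gain only, no bias ("beta") term

    def forward(self, x):
        # biased variance and mean over the last dimension, normalized with rsqrt
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * (var + self.eps).rsqrt() * self.g

# norm = BiaslessLayerNorm(512)
# norm(torch.randn(2, 77, 512))   # same shape out, normalized over the feature dimension

ChanLayerNorm receives the same rsqrt treatment in the hunk above, with the statistics taken over the channel dimension instead.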

View File

@@ -216,6 +216,7 @@ class UnetConfig(BaseModel):
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16
@@ -345,17 +346,17 @@ class TrainDecoderConfig(BaseModel):
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url
if using_text_encodings:
if using_text_embeddings:
# Then we need some way to get the embeddings
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
if using_clip:
if using_text_encodings:
if using_text_embeddings:
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
else:
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
if text_emb_url:
assert using_text_encodings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return values
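The new self_attn option added to Unet and to UnetConfig is normalized per stage by the updated cast_tuple, so it can be given either as a single value or as one entry per resolution. A small sketch of that normalization, using the helpers as they appear in the diff (the example values are illustrative, not taken from the repository's docs):

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)
    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
    if exists(length):
        assert len(out) == length
    return out

num_stages = 4                                              # e.g. dim_mults = (1, 2, 4, 8)
print(cast_tuple(False, num_stages))                        # (False, False, False, False): no full attention anywhere
print(cast_tuple((False, False, True, True), num_stages))   # full self attention only on the two deepest stages

Stages whose per-stage flag is truthy get the full EinopsToAndFrom/Attention block; otherwise the existing sparse_attn path or an nn.Identity is used, as the hunks above show.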

View File

@@ -1 +1 @@
__version__ = '0.15.2'
__version__ = '0.16.1'