add ability to specify full self attention on specific stages in the unet

This commit is contained in:
Phil Wang
2022-07-01 10:22:07 -07:00
parent 282c35930f
commit 3d23ba4aa5
3 changed files with 42 additions and 18 deletions

View File

@@ -63,11 +63,16 @@ def default(val, d):
return val
return d() if callable(d) else d
def cast_tuple(val, length = 1):
def cast_tuple(val, length = None):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else ((val,) * length)
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
if exists(length):
assert len(out) == length
return out
def module_device(module):
return next(module.parameters()).device
@@ -1341,6 +1346,7 @@ class Unet(nn.Module):
dim_mults=(1, 2, 4, 8),
channels = 3,
channels_out = None,
self_attn = False,
attn_dim_head = 32,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1387,6 +1393,8 @@ class Unet(nn.Module):
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
num_stages = len(in_out)
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
@@ -1450,14 +1458,16 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
self_attn = cast_tuple(self_attn, num_stages)
create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
# resnet block klass
resnet_groups = cast_tuple(resnet_groups, len(in_out))
resnet_groups = cast_tuple(resnet_groups, num_stages)
top_level_resnet_group = first(resnet_groups)
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
assert len(resnet_groups) == len(in_out)
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)
# downsample klass
@@ -1479,9 +1489,9 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
skip_connect_dims = [] # keeping track of skip connection dimensions
skip_connect_dims = [] # keeping track of skip connection dimensions
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
@@ -1489,30 +1499,42 @@ class Unet(nn.Module):
dim_layer = dim_out if memory_efficient else dim_in
skip_connect_dims.append(dim_layer)
attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_layer)
elif sparse_attn:
attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_layer, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
]))
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None
skip_connect_dim = skip_connect_dims.pop()
attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_out)
elif sparse_attn:
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention,
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))
@@ -1690,18 +1712,19 @@ class Unet(nn.Module):
hiddens = []
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)
x = init_block(x, t, c)
x = sparse_attn(x)
hiddens.append(x)
for resnet_block in resnet_blocks:
x = resnet_block(x, t, c)
hiddens.append(x)
x = attn(x)
hiddens.append(x)
if exists(post_downsample):
x = post_downsample(x)
@@ -1714,15 +1737,15 @@ class Unet(nn.Module):
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
for init_block, resnet_blocks, attn, upsample in self.ups:
x = connect_skip(x)
x = init_block(x, t, c)
x = sparse_attn(x)
for resnet_block in resnet_blocks:
x = connect_skip(x)
x = resnet_block(x, t, c)
x = attn(x)
x = upsample(x)
x = torch.cat((x, r), dim = 1)

View File

@@ -216,6 +216,7 @@ class UnetConfig(BaseModel):
cond_on_text_encodings: bool = None
cond_dim: int = None
channels: int = 3
self_attn: ListOrTuple(int)
attn_dim_head: int = 32
attn_heads: int = 16

View File

@@ -1 +1 @@
__version__ = '0.15.2'
__version__ = '0.15.3'