bring back convtranspose2d upsampling, allow for nearest upsample with hyperparam, change kernel size of last conv to 1, make configurable, cleanup

This commit is contained in:
Phil Wang
2022-07-01 09:21:47 -07:00
parent 8f2466f1cd
commit a922a539de
3 changed files with 15 additions and 28 deletions

View File

@@ -1112,15 +1112,6 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022},
url = {https://arxiv.org/abs/2204.01697}
}
```
```bibtex
@article{Yu2021VectorquantizedIM,
title = {Vector-quantized Image Modeling with Improved VQGAN},

View File

@@ -1093,7 +1093,11 @@ class DiffusionPrior(nn.Module):
# decoder
def Upsample(dim, dim_out = None):
def ConvTransposeUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
def NearestUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
@@ -1256,20 +1260,6 @@ class CrossAttention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class GridAttention(nn.Module):
def __init__(self, *args, window_size = 8, **kwargs):
super().__init__()
self.window_size = window_size
self.attn = Attention(*args, **kwargs)
def forward(self, x):
h, w = x.shape[-2:]
wsz = self.window_size
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
out = self.attn(x)
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
class LinearAttention(nn.Module):
def __init__(
self,
@@ -1369,6 +1359,8 @@ class Unet(nn.Module):
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
scale_skip_connection = False,
nearest_upsample = False,
final_conv_kernel_size = 1,
**kwargs
):
super().__init__()
@@ -1473,6 +1465,10 @@ class Unet(nn.Module):
if cross_embed_downsample:
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
# upsample klass
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
# give memory efficient unet an initial resnet block
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
@@ -1517,11 +1513,11 @@ class Unet(nn.Module):
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
Upsample(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
]))
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
self.to_out = nn.Conv2d(dim, self.channels_out, 3, padding = 1)
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
@@ -1716,7 +1712,7 @@ class Unet(nn.Module):
x = self.mid_block2(x, t, mid_c)
connect_skip = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
x = connect_skip(x)

View File

@@ -1 +1 @@
__version__ = '0.14.1'
__version__ = '0.15.0'