Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
bring back convtranspose2d upsampling, allow for nearest upsample with hyperparam, change kernel size of last conv to 1, make configurable, cleanup
@@ -1112,15 +1112,6 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
-```bibtex
-@inproceedings{Tu2022MaxViTMV,
-    title = {MaxViT: Multi-Axis Vision Transformer},
-    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
-    year = {2022},
-    url = {https://arxiv.org/abs/2204.01697}
-}
-```
-
 ```bibtex
 @article{Yu2021VectorquantizedIM,
     title = {Vector-quantized Image Modeling with Improved VQGAN},
@@ -1093,7 +1093,11 @@ class DiffusionPrior(nn.Module):
 
 # decoder
 
-def Upsample(dim, dim_out = None):
+def ConvTransposeUpsample(dim, dim_out = None):
+    dim_out = default(dim_out, dim)
+    return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
+
+def NearestUpsample(dim, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Sequential(
         nn.Upsample(scale_factor = 2, mode = 'nearest'),
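For reference, a self-contained sketch of the two upsamplers this hunk introduces. The `default` helper is inlined here, and the closing `nn.Conv2d(dim, dim_out, 3, padding = 1)` of `NearestUpsample` is assumed, since the hunk's trailing context cuts it off; both builders double the spatial resolution of a feature map.

```python
# self-contained sketch of the two upsamplers; `default` is inlined and the final
# Conv2d of NearestUpsample is assumed (the hunk context cuts off before it)
import torch
from torch import nn

def default(val, d):
    return val if val is not None else d

def ConvTransposeUpsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)
    # kernel 4, stride 2, padding 1 -> exactly doubles height and width
    return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)

def NearestUpsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)
    # parameter-free nearest-neighbor resize followed by a 3x3 conv (assumed closing lines)
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1)
    )

x = torch.randn(1, 32, 16, 16)
print(ConvTransposeUpsample(32)(x).shape)    # torch.Size([1, 32, 32, 32])
print(NearestUpsample(32, 64)(x).shape)      # torch.Size([1, 64, 32, 32])
```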
@@ -1256,20 +1260,6 @@ class CrossAttention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-class GridAttention(nn.Module):
-    def __init__(self, *args, window_size = 8, **kwargs):
-        super().__init__()
-        self.window_size = window_size
-        self.attn = Attention(*args, **kwargs)
-
-    def forward(self, x):
-        h, w = x.shape[-2:]
-        wsz = self.window_size
-        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
-        out = self.attn(x)
-        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
-        return out
-
 class LinearAttention(nn.Module):
     def __init__(
         self,
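The `GridAttention` block removed by this hunk (part of the "cleanup" in the commit message) windows the feature map with an einops rearrange before attending. A minimal sketch of that rearrange round-trip, with the attention call stubbed out, to show the windowing reshape is a pure permutation:

```python
# sketch: the grid rearrange used by the removed GridAttention, with the attention
# call stubbed out, just to show the windowing reshape round-trips exactly
import torch
from einops import rearrange

wsz = 8                                          # the removed class defaulted window_size to 8
x = torch.randn(2, 16, 32, 32)                   # (batch, channels, height, width)
h, w = x.shape[-2:]

windows = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
out = windows                                    # stand-in for self.attn(windows)
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)

assert torch.equal(out, x)                       # the rearrange pair is a permutation and its exact inverse
```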
@@ -1369,6 +1359,8 @@ class Unet(nn.Module):
         cross_embed_downsample_kernel_sizes = (2, 4),
         memory_efficient = False,
         scale_skip_connection = False,
+        nearest_upsample = False,
+        final_conv_kernel_size = 1,
         **kwargs
     ):
         super().__init__()
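A hedged usage sketch of the two new constructor arguments; the remaining arguments follow the Unet usage example in the repository README, and all values here are illustrative only:

```python
# illustrative only: passing the two new hyperparameters; the other arguments
# follow the Unet usage example in the repository README
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    nearest_upsample = False,        # keep the ConvTranspose2d upsampler (the default)
    final_conv_kernel_size = 1       # 1x1 final projection (new default) instead of the previous fixed 3x3
)
```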
@@ -1473,6 +1465,10 @@ class Unet(nn.Module):
         if cross_embed_downsample:
             downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
 
+        # upsample klass
+
+        upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
+
         # give memory efficient unet an initial resnet block
 
         self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
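The selected `upsample_klass` is then used at every upsampling stage (next hunk). A small sketch of the same selection; the import path `dalle2_pytorch.dalle2_pytorch` is an assumption about the package layout:

```python
# sketch of the flag-based selection performed in Unet.__init__;
# the import path is an assumption about the package layout
import torch
from dalle2_pytorch.dalle2_pytorch import ConvTransposeUpsample, NearestUpsample

for nearest_upsample in (False, True):
    upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
    up = upsample_klass(64, 32)
    # both choices double height and width, so they are drop-in replacements for one another
    print(upsample_klass.__name__, up(torch.randn(1, 64, 8, 8)).shape)   # -> torch.Size([1, 32, 16, 16])
```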
@@ -1517,11 +1513,11 @@ class Unet(nn.Module):
                 ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                Upsample(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
+                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))
 
         self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
-        self.to_out = nn.Conv2d(dim, self.channels_out, 3, padding = 1)
+        self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
 
         # if the current settings for the unet are not correct
         # for cascading DDPM, then reinit the unet with the right settings
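The final projection's padding is now derived from the configurable kernel size, so the output resolution is unchanged for any odd kernel. A quick check, with illustrative channel counts:

```python
# quick check that padding = kernel_size // 2 keeps the spatial size for odd kernels,
# so the final conv can be 1x1 (the new default) or 3x3 without changing the output shape
import torch
from torch import nn

x = torch.randn(1, 128, 64, 64)                       # illustrative channel count and resolution
for final_conv_kernel_size in (1, 3):
    to_out = nn.Conv2d(128, 3, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
    print(final_conv_kernel_size, to_out(x).shape)    # both print torch.Size([1, 3, 64, 64])
```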
@@ -1716,7 +1712,7 @@ class Unet(nn.Module):
 
         x = self.mid_block2(x, t, mid_c)
 
-        connect_skip = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
+        connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
 
         for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
             x = connect_skip(x)
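Renaming the lambda's parameter from `x` to `fmap` stops it from shadowing the `x` of the surrounding forward pass. A minimal sketch of the scaled skip concatenation it performs, with a toy `hiddens` stack and an illustrative scale value:

```python
# minimal sketch of the scaled skip concatenation; `hiddens` stands in for the
# feature maps pushed onto the stack during the downsampling path
import torch

skip_connect_scale = 2 ** -0.5                  # illustrative; effectively 1. when skip scaling is disabled
hiddens = [torch.randn(1, 64, 32, 32)]          # toy stack with one saved feature map

connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * skip_connect_scale), dim = 1)

x = torch.randn(1, 64, 32, 32)
print(connect_skip(x).shape)                    # torch.Size([1, 128, 32, 32]) -- channels concatenated
```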
@@ -1 +1 @@
-__version__ = '0.14.1'
+__version__ = '0.15.0'