diff --git a/README.md b/README.md
index 8d89ab6..24f8cf1 100644
--- a/README.md
+++ b/README.md
@@ -1112,15 +1112,6 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```

-```bibtex
-@inproceedings{Tu2022MaxViTMV,
-    title = {MaxViT: Multi-Axis Vision Transformer},
-    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
-    year = {2022},
-    url = {https://arxiv.org/abs/2204.01697}
-}
-```
-
 ```bibtex
 @article{Yu2021VectorquantizedIM,
     title = {Vector-quantized Image Modeling with Improved VQGAN},
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index f3f11fa..0eb9e36 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1093,7 +1093,11 @@ class DiffusionPrior(nn.Module):

 # decoder

-def Upsample(dim, dim_out = None):
+def ConvTransposeUpsample(dim, dim_out = None):
+    dim_out = default(dim_out, dim)
+    return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
+
+def NearestUpsample(dim, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Sequential(
         nn.Upsample(scale_factor = 2, mode = 'nearest'),
@@ -1256,20 +1260,6 @@ class CrossAttention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

-class GridAttention(nn.Module):
-    def __init__(self, *args, window_size = 8, **kwargs):
-        super().__init__()
-        self.window_size = window_size
-        self.attn = Attention(*args, **kwargs)
-
-    def forward(self, x):
-        h, w = x.shape[-2:]
-        wsz = self.window_size
-        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
-        out = self.attn(x)
-        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
-        return out
-
 class LinearAttention(nn.Module):
     def __init__(
         self,
@@ -1369,6 +1359,8 @@ class Unet(nn.Module):
         cross_embed_downsample_kernel_sizes = (2, 4),
         memory_efficient = False,
         scale_skip_connection = False,
+        nearest_upsample = False,
+        final_conv_kernel_size = 1,
         **kwargs
     ):
         super().__init__()
@@ -1473,6 +1465,10 @@ class Unet(nn.Module):
         if cross_embed_downsample:
             downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

+        # upsample klass
+
+        upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
+
         # give memory efficient unet an initial resnet block

         self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
@@ -1517,11 +1513,11 @@ class Unet(nn.Module):
                 ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
-                Upsample(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
+                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))

         self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
-        self.to_out = nn.Conv2d(dim, self.channels_out, 3, padding = 1)
+        self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)

         # if the current settings for the unet are not correct
         # for cascading DDPM, then reinit the unet with the right settings
@@ -1716,7 +1712,7 @@ class Unet(nn.Module):

         x = self.mid_block2(x, t, mid_c)

-        connect_skip = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
+        connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)

         for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
             x = connect_skip(x)
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 092052c..a842d05 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.14.1'
+__version__ = '0.15.0'
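
For anyone trying out this change, below is a minimal sketch of how the two new `Unet` keyword arguments introduced by this patch (`nearest_upsample` and `final_conv_kernel_size`) might be exercised. Only those two arguments come from this diff; the remaining hyperparameters mirror the README's decoder example and are illustrative placeholders, not prescribed values.

```python
from dalle2_pytorch import Unet

# illustrative hyperparameters, following the README's decoder example;
# only the last two keyword arguments are introduced by this patch
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    nearest_upsample = True,       # nearest-neighbor upsample + conv, instead of the new ConvTranspose2d default
    final_conv_kernel_size = 3     # keep a 3x3 output convolution (the new default is 1x1)
)
```

Note that the new defaults (`nearest_upsample = False`, `final_conv_kernel_size = 1`) differ from the previous behavior, which used nearest-neighbor upsampling and a 3x3 output convolution; the values shown above reproduce the old choices.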