diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index e22a073..005a920 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1490,7 +1490,8 @@ class LinearAttention(nn.Module): self, dim, dim_head = 32, - heads = 8 + heads = 8, + **kwargs ): super().__init__() self.scale = dim_head ** -0.5 diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 4e7c72a..9e0feee 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.4.3' +__version__ = '1.4.4'