diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 862440e..4626080 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1926,7 +1926,7 @@ class Unet(nn.Module): hiddens.append(x) x = attn(x) - hiddens.append(x) + hiddens.append(x.contiguous()) if exists(post_downsample): x = post_downsample(x) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 7fe1493..d1d123f 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.24.1' +__version__ = '0.24.2'