From 0069857cf838ee2d1ae874df7829cb4324ac6127 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 20 Apr 2023 07:05:29 -0700
Subject: [PATCH] remove einops exts for better pytorch 2.0 compile
 compatibility

---
 dalle2_pytorch/dalle2_pytorch.py | 52 +++++++++++++++++++++-----------
 dalle2_pytorch/version.py        |  2 +-
 dalle2_pytorch/vqgan_vae.py      |  5 ++-
 setup.py                         |  3 +-
 4 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index c0211fc..71a6e4c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -12,10 +12,8 @@ from torch.utils.checkpoint import checkpoint
 from torch import nn, einsum
 import torchvision.transforms as T
 
-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
-from einops_exts import rearrange_many, repeat_many, check_shape
-from einops_exts.torch import EinopsToAndFrom
 
 from kornia.filters import gaussian_blur2d
 import kornia.augmentation as K
@@ -669,6 +667,23 @@ class NoiseScheduler(nn.Module):
             return loss
 
         return loss * extract(self.p2_loss_weight, times, loss.shape)
 
+# rearrange image to sequence
+
+class RearrangeToSequence(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x = rearrange(x, 'b c ... -> b ... c')
+        x, ps = pack([x], 'b * c')
+
+        x = self.fn(x)
+
+        x, = unpack(x, ps, 'b * c')
+        x = rearrange(x, 'b ... c -> b c ...')
+        return x
+
 # diffusion prior
 
 class LayerNorm(nn.Module):
@@ -867,7 +882,7 @@ class Attention(nn.Module):
 
         # add null key / value for classifier free guidance in prior net
 
-        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
+        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
@@ -1629,14 +1644,10 @@ class ResnetBlock(nn.Module):
         self.cross_attn = None
 
         if exists(cond_dim):
-            self.cross_attn = EinopsToAndFrom(
-                'b c h w',
-                'b (h w) c',
-                CrossAttention(
-                    dim = dim_out,
-                    context_dim = cond_dim,
-                    cosine_sim = cosine_sim_cross_attn
-                )
+            self.cross_attn = CrossAttention(
+                dim = dim_out,
+                context_dim = cond_dim,
+                cosine_sim = cosine_sim_cross_attn
             )
 
         self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
@@ -1655,8 +1666,15 @@ class ResnetBlock(nn.Module):
 
         if exists(self.cross_attn):
             assert exists(cond)
+
+            h = rearrange(h, 'b c ... -> b ... c')
+            h, ps = pack([h], 'b * c')
+
             h = self.cross_attn(h, context = cond) + h
+
+            h, = unpack(h, ps, 'b * c')
+            h = rearrange(h, 'b ... c -> b c ...')
 
         h = self.block2(h)
 
         return h + self.res_conv(x)
@@ -1702,11 +1720,11 @@ class CrossAttention(nn.Module):
 
         q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
 
-        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
 
         # add null key / value for classifier free guidance in prior net
 
-        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
+        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
 
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
@@ -1759,7 +1777,7 @@ class LinearAttention(nn.Module):
 
         fmap = self.norm(fmap)
         q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
-        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
 
         q = q.softmax(dim = -1)
         k = k.softmax(dim = -2)
@@ -1993,7 +2011,7 @@ class Unet(nn.Module):
 
         self_attn = cast_tuple(self_attn, num_stages)
 
-        create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
+        create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))
 
         # resnet block klass
 
@@ -3230,7 +3248,7 @@ class Decoder(nn.Module):
         learned_variance = self.learned_variance[unet_index]
 
         b, c, h, w, device, = *image.shape, image.device
-        check_shape(image, 'b c h w', c = self.channels)
+        assert image.shape[1] == self.channels
         assert h >= target_image_size and w >= target_image_size
 
         times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index aba17a1..e4f2ad4 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.12.4'
+__version__ = '1.14.0'
diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index 1c073fd..01d7258 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -11,8 +11,7 @@ import torch.nn.functional as F
 from torch.autograd import grad as torch_grad
 import torchvision
 
-from einops import rearrange, reduce, repeat
-from einops_exts import rearrange_many
+from einops import rearrange, reduce, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 # constants
@@ -408,7 +407,7 @@ class Attention(nn.Module):
         x = self.norm(x)
 
         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
-        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
         q = q * self.scale
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
diff --git a/setup.py b/setup.py
index ddbce11..cf94e63 100644
--- a/setup.py
+++ b/setup.py
@@ -30,8 +30,7 @@ setup(
     'clip-anytorch>=2.5.2',
     'coca-pytorch>=0.0.5',
     'ema-pytorch>=0.0.7',
-    'einops>=0.4',
-    'einops-exts>=0.0.3',
+    'einops>=0.6',
     'embedding-reader',
     'kornia>=0.5.4',
     'numpy',
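
Migration note: each removed einops_exts helper translates to plain einops one-for-one. The snippet below is a minimal sketch of the two recurring substitutions, not part of the patch; it assumes einops>=0.6, and the nn.Identity() stand-in is hypothetical, taking the place of the attention modules the library actually wraps.

import torch
from torch import nn
from einops import rearrange, repeat, pack, unpack

# rearrange_many / repeat_many become a plain map over rearrange / repeat
q = k = v = torch.randn(2, 16, 64)                # (batch, seq, heads * dim_head)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = 8), (q, k, v))
nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = 2), torch.randn(2, 64).unbind(dim = -2))

# EinopsToAndFrom('b c h w', 'b (h w) c', fn) becomes an explicit pack/unpack
# round trip, the same pattern the new RearrangeToSequence module wraps
fn = nn.Identity()                                # hypothetical stand-in for an attention module
x = torch.randn(2, 32, 8, 8)                      # (batch, channels, height, width)
seq = rearrange(x, 'b c ... -> b ... c')          # move channels last
seq, ps = pack([seq], 'b * c')                    # flatten spatial dims: (b, h*w, c)
seq = fn(seq)                                     # the wrapped sequence module runs here
seq, = unpack(seq, ps, 'b * c')                   # restore (b, h, w, c)
out = rearrange(seq, 'b ... c -> b c ...')        # channels first again
assert out.shape == x.shape

Because pack records the original shape of the collapsed axes in ps, a single RearrangeToSequence wrapper covers feature maps of any spatial rank, whereas EinopsToAndFrom was pinned to the 'b c h w' pattern given at construction.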