Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-20 02:04:19 +01:00
remove einops exts for better pytorch 2.0 compile compatibility
@@ -11,8 +11,7 @@ import torch.nn.functional as F
 from torch.autograd import grad as torch_grad
 import torchvision
 
-from einops import rearrange, reduce, repeat
-from einops_exts import rearrange_many
+from einops import rearrange, reduce, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 # constants
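Note: the import change pulls in pack and unpack, which einops ships natively since version 0.6 and which cover the grouped-tensor helpers that einops_exts provided. Below is a minimal sketch of how they round-trip a (q, k, v) tuple; the tensor shapes are illustrative, not taken from the repo.

import torch
from einops import pack, unpack

q = torch.randn(2, 16, 64)   # (batch, seq, dim) - illustrative shapes
k = torch.randn(2, 16, 64)
v = torch.randn(2, 16, 64)

# pack stacks the three tensors along the axis marked '*'
qkv, ps = pack([q, k, v], '* b n d')     # qkv: (3, 2, 16, 64)

# unpack reverses it exactly, using the recorded packed shapes
q2, k2, v2 = unpack(qkv, ps, '* b n d')
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)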
@@ -408,7 +407,7 @@ class Attention(nn.Module):
         x = self.norm(x)
 
         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
-        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
         q = q * self.scale
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
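The replacement in Attention is behavior-preserving: rearrange_many simply mapped rearrange over a tuple of tensors, so plain map with einops.rearrange produces identical head-split tensors while keeping only core einops ops in the traced graph for torch.compile. A minimal sketch of the pattern, under assumed illustrative dimensions:

import torch
from einops import rearrange

b, n, h, d = 2, 16, 8, 64
qkv = torch.randn(b, n, 3 * h * d)           # stand-in for a to_qkv projection output

q, k, v = qkv.chunk(3, dim=-1)               # each: (b, n, h * d)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

assert q.shape == k.shape == v.shape == (b, h, n, d)

sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)  # attention logits
assert sim.shape == (b, h, n, n)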