mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
remove einops exts for better pytorch 2.0 compile compatibility
This commit is contained in:
@@ -12,10 +12,8 @@ from torch.utils.checkpoint import checkpoint
|
|||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
|
|
||||||
from einops import rearrange, repeat, reduce
|
from einops import rearrange, repeat, reduce, pack, unpack
|
||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
|
||||||
from einops_exts.torch import EinopsToAndFrom
|
|
||||||
|
|
||||||
from kornia.filters import gaussian_blur2d
|
from kornia.filters import gaussian_blur2d
|
||||||
import kornia.augmentation as K
|
import kornia.augmentation as K
|
||||||
@@ -669,6 +667,23 @@ class NoiseScheduler(nn.Module):
|
|||||||
return loss
|
return loss
|
||||||
return loss * extract(self.p2_loss_weight, times, loss.shape)
|
return loss * extract(self.p2_loss_weight, times, loss.shape)
|
||||||
|
|
||||||
|
# rearrange image to sequence
|
||||||
|
|
||||||
|
class RearrangeToSequence(nn.Module):
|
||||||
|
def __init__(self, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = rearrange(x, 'b c ... -> b ... c')
|
||||||
|
x, ps = pack([x], 'b * c')
|
||||||
|
|
||||||
|
x = self.fn(x)
|
||||||
|
|
||||||
|
x, = unpack(x, ps, 'b * c')
|
||||||
|
x = rearrange(x, 'b ... c -> b c ...')
|
||||||
|
return x
|
||||||
|
|
||||||
# diffusion prior
|
# diffusion prior
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
@@ -867,7 +882,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# add null key / value for classifier free guidance in prior net
|
# add null key / value for classifier free guidance in prior net
|
||||||
|
|
||||||
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
|
nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
|
||||||
k = torch.cat((nk, k), dim = -2)
|
k = torch.cat((nk, k), dim = -2)
|
||||||
v = torch.cat((nv, v), dim = -2)
|
v = torch.cat((nv, v), dim = -2)
|
||||||
|
|
||||||
@@ -1629,14 +1644,10 @@ class ResnetBlock(nn.Module):
|
|||||||
self.cross_attn = None
|
self.cross_attn = None
|
||||||
|
|
||||||
if exists(cond_dim):
|
if exists(cond_dim):
|
||||||
self.cross_attn = EinopsToAndFrom(
|
self.cross_attn = CrossAttention(
|
||||||
'b c h w',
|
dim = dim_out,
|
||||||
'b (h w) c',
|
context_dim = cond_dim,
|
||||||
CrossAttention(
|
cosine_sim = cosine_sim_cross_attn
|
||||||
dim = dim_out,
|
|
||||||
context_dim = cond_dim,
|
|
||||||
cosine_sim = cosine_sim_cross_attn
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
|
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
|
||||||
@@ -1655,8 +1666,15 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
if exists(self.cross_attn):
|
if exists(self.cross_attn):
|
||||||
assert exists(cond)
|
assert exists(cond)
|
||||||
|
|
||||||
|
h = rearrange(h, 'b c ... -> b ... c')
|
||||||
|
h, ps = pack([h], 'b * c')
|
||||||
|
|
||||||
h = self.cross_attn(h, context = cond) + h
|
h = self.cross_attn(h, context = cond) + h
|
||||||
|
|
||||||
|
h, = unpack(h, ps, 'b * c')
|
||||||
|
h = rearrange(h, 'b ... c -> b c ...')
|
||||||
|
|
||||||
h = self.block2(h)
|
h = self.block2(h)
|
||||||
return h + self.res_conv(x)
|
return h + self.res_conv(x)
|
||||||
|
|
||||||
@@ -1702,11 +1720,11 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||||
|
|
||||||
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
|
||||||
|
|
||||||
# add null key / value for classifier free guidance in prior net
|
# add null key / value for classifier free guidance in prior net
|
||||||
|
|
||||||
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
|
nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
|
||||||
|
|
||||||
k = torch.cat((nk, k), dim = -2)
|
k = torch.cat((nk, k), dim = -2)
|
||||||
v = torch.cat((nv, v), dim = -2)
|
v = torch.cat((nv, v), dim = -2)
|
||||||
@@ -1759,7 +1777,7 @@ class LinearAttention(nn.Module):
|
|||||||
|
|
||||||
fmap = self.norm(fmap)
|
fmap = self.norm(fmap)
|
||||||
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
|
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
|
||||||
q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
|
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
|
||||||
|
|
||||||
q = q.softmax(dim = -1)
|
q = q.softmax(dim = -1)
|
||||||
k = k.softmax(dim = -2)
|
k = k.softmax(dim = -2)
|
||||||
@@ -1993,7 +2011,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
self_attn = cast_tuple(self_attn, num_stages)
|
self_attn = cast_tuple(self_attn, num_stages)
|
||||||
|
|
||||||
create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
|
create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))
|
||||||
|
|
||||||
# resnet block klass
|
# resnet block klass
|
||||||
|
|
||||||
@@ -3230,7 +3248,7 @@ class Decoder(nn.Module):
|
|||||||
learned_variance = self.learned_variance[unet_index]
|
learned_variance = self.learned_variance[unet_index]
|
||||||
b, c, h, w, device, = *image.shape, image.device
|
b, c, h, w, device, = *image.shape, image.device
|
||||||
|
|
||||||
check_shape(image, 'b c h w', c = self.channels)
|
assert image.shape[1] == self.channels
|
||||||
assert h >= target_image_size and w >= target_image_size
|
assert h >= target_image_size and w >= target_image_size
|
||||||
|
|
||||||
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
|
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.12.4'
|
__version__ = '1.14.0'
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ import torch.nn.functional as F
|
|||||||
from torch.autograd import grad as torch_grad
|
from torch.autograd import grad as torch_grad
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
from einops import rearrange, reduce, repeat
|
from einops import rearrange, reduce, repeat, pack, unpack
|
||||||
from einops_exts import rearrange_many
|
|
||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
|
|
||||||
# constants
|
# constants
|
||||||
@@ -408,7 +407,7 @@ class Attention(nn.Module):
|
|||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||||
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||||
|
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -30,8 +30,7 @@ setup(
|
|||||||
'clip-anytorch>=2.5.2',
|
'clip-anytorch>=2.5.2',
|
||||||
'coca-pytorch>=0.0.5',
|
'coca-pytorch>=0.0.5',
|
||||||
'ema-pytorch>=0.0.7',
|
'ema-pytorch>=0.0.7',
|
||||||
'einops>=0.4',
|
'einops>=0.6',
|
||||||
'einops-exts>=0.0.3',
|
|
||||||
'embedding-reader',
|
'embedding-reader',
|
||||||
'kornia>=0.5.4',
|
'kornia>=0.5.4',
|
||||||
'numpy',
|
'numpy',
|
||||||
|
|||||||
Reference in New Issue
Block a user