mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-24 00:34:20 +01:00)
@@ -1,5 +1,6 @@
+import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional

 import torch
@@ -7,8 +8,6 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from packaging import version
 from torch import nn
-from ..util import exists, default
-
-

+logger = logging.getLogger(__name__)

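For context, a minimal sketch (not part of the diff) of how a module-level logger like the one added here replaces ad-hoc print calls; the message text and level below are illustrative:

import logging

logger = logging.getLogger(__name__)

# The application opts in to output; library modules only emit records.
logging.basicConfig(level=logging.DEBUG)

# e.g. instead of: print("dispatching into backend", backend)
logger.debug("dispatching into backend %s", "flash-attention")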
@@ -59,11 +58,25 @@ except:
 from .diffusionmodules.util import checkpoint


-def uniq(arr):  # TODO: this seems unused
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
     return {el: True for el in arr}.keys()


-def init_(tensor):  # TODO: this seems unused
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
     dim = tensor.shape[-1]
     std = 1 / math.sqrt(dim)
     tensor.uniform_(-std, std)
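A short usage sketch (illustrative values, not from the diff) for the helpers this hunk defines locally:

import math
import torch

# default(val, d): fall back to d (or d() if d is a function) when val is None,
# e.g. default(None, 8) -> 8, default(16, 8) -> 16

# max_neg_value(t): the most negative finite value of t's dtype, used as the
# fill value for masked attention logits so they vanish under softmax
t = torch.zeros(1, dtype=torch.float16)
assert -torch.finfo(t.dtype).max == -65504.0

# init_(w): in-place uniform init with bound 1 / sqrt(last dim)
w = torch.empty(4, 16)
bound = 1 / math.sqrt(w.shape[-1])
w.uniform_(-bound, bound)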
@@ -243,6 +256,23 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

+        ## old
+        """
+        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        del q, k
+
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        """
+        ## new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
             out = F.scaled_dot_product_attention(
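The hunk above swaps the manual einsum attention for torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0). A small self-contained check (shapes and sizes are illustrative, not from the diff) that the two paths agree on unmasked inputs, given SDPA's default scale of dim_head ** -0.5:

import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 16, 32                      # batch, heads, tokens, head dim
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# old path: explicit similarity, softmax, weighted sum
sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * d**-0.5
out_old = torch.einsum("b h i j, b h j d -> b h i d", sim.softmax(dim=-1), v)

# new path: fused kernel; scale defaults to d ** -0.5 when not given
out_new = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(out_old, out_new, atol=1e-5)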