mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-20 14:54:21 +01:00
fall back to vanilla if xformers is not available (#51)
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
# pytorch_diffusion + derived encoder decoder
|
# pytorch_diffusion + derived encoder decoder
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
@@ -291,6 +293,13 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
|||||||
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||||
)
|
)
|
||||||
attn_type = "vanilla-xformers"
|
attn_type = "vanilla-xformers"
|
||||||
|
if attn_type == "vanilla-xformers" and not XFORMERS_IS_AVAILABLE:
|
||||||
|
warnings.warn(
|
||||||
|
f"Requested attention type {attn_type!r} but Xformers is not available; "
|
||||||
|
f"falling back to vanilla attention"
|
||||||
|
)
|
||||||
|
attn_type = "vanilla"
|
||||||
|
attn_kwargs = None
|
||||||
logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||||
if attn_type == "vanilla":
|
if attn_type == "vanilla":
|
||||||
assert attn_kwargs is None
|
assert attn_kwargs is None
|
||||||
|
|||||||
Reference in New Issue
Block a user