Mirror of https://github.com/Stability-AI/generative-models.git (synced 2026-01-02 21:24:25 +01:00)
This reverts commit 6f6d3f8716.
@@ -1,5 +1,4 @@
 # pytorch_diffusion + derived encoder decoder
-import logging
 import math
 from typing import Any, Callable, Optional
 
@@ -9,8 +8,6 @@ import torch.nn as nn
 from einops import rearrange
 from packaging import version
 
-logger = logging.getLogger(__name__)
-
 try:
     import xformers
     import xformers.ops
@@ -18,7 +15,7 @@ try:
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
 
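The hunk above restores the print-based fallback inside the guarded xformers import. A minimal standalone sketch of that optional-dependency pattern (same guard as the file, but catching ImportError explicitly instead of the bare `except:` the file uses):

```python
# Sketch of the optional-dependency guard: try the import, record availability
# in a module-level flag, and fall back gracefully if xformers is missing.
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILABLE = True
except ImportError:
    XFORMERS_IS_AVAILABLE = False
    print("no module 'xformers'. Processing without...")
```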
@@ -291,14 +288,12 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
             f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
         )
         attn_type = "vanilla-xformers"
-    logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
         return AttnBlock(in_channels)
     elif attn_type == "vanilla-xformers":
-        logger.debug(
-            f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
-        )
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
     elif type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
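Untouched by this revert, but visible as context in the hunk above: the branch `elif type == "memory-efficient-cross-attn":` compares the Python builtin `type` against a string, so the condition is always False and the branch is unreachable; it presumably was meant to read `attn_type`. A one-line check illustrates this:

```python
# The builtin `type` is a class object, never equal to a string, so this
# condition can never select the memory-efficient-cross-attn branch.
print(type == "memory-efficient-cross-attn")  # -> False
```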
@@ -638,8 +633,10 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        logger.debug(
-            f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
-        )
+        print(
+            "Working with z of shape {} = {} dimensions.".format(
+                self.z_shape, np.prod(self.z_shape)
+            )
+        )
 
         make_attn_cls = self._make_attn()
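Throughout this file and the second file whose hunks follow, the revert replaces logger.debug / logger.warning calls with plain print. The practical difference: print always writes to stdout, while logger.debug output stays hidden until the application configures logging. A standalone sketch (not part of the repository) of enabling those debug messages with the logging-based variant:

```python
# Standalone sketch: with module-level loggers, debug messages only appear
# once a handler and level are configured, e.g. via basicConfig.
import logging

logging.basicConfig(level=logging.DEBUG, format="%(name)s: %(levelname)s: %(message)s")

logger = logging.getLogger(__name__)
logger.debug("making attention of type 'vanilla' with 512 in_channels")  # now visible
```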
@@ -1,4 +1,3 @@
-import logging
 import math
 from abc import abstractmethod
 from functools import partial
@@ -22,8 +21,6 @@ from ...modules.diffusionmodules.util import (
 )
 from ...util import default, exists
 
-logger = logging.getLogger(__name__)
-
 
 # dummy replace
 def convert_module_to_f16(x):
@@ -180,13 +177,13 @@ class Downsample(nn.Module):
         self.dims = dims
         stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
         if use_conv:
-            logger.debug(
-                f"Building a Downsample layer with {dims} dims.\n"
+            print(f"Building a Downsample layer with {dims} dims.")
+            print(
                 f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
                 f"kernel-size: 3, stride: {stride}, padding: {padding}"
             )
             if dims == 3:
-                logger.debug(f" --> Downsampling third axis (time): {third_down}")
+                print(f" --> Downsampling third axis (time): {third_down}")
             self.op = conv_nd(
                 dims,
                 self.channels,
@@ -273,7 +270,7 @@ class ResBlock(TimestepBlock):
             2 * self.out_channels if use_scale_shift_norm else self.out_channels
         )
         if self.skip_t_emb:
-            logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
+            print(f"Skipping timestep embedding in {self.__class__.__name__}")
             assert not self.use_scale_shift_norm
             self.emb_layers = None
             self.exchange_temb_dims = False
@@ -622,12 +619,12 @@ class UNetModel(nn.Module):
                     range(len(num_attention_blocks)),
                 )
             )
-            logger.warning(
+            print(
                 f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
-            )
+            )  # todo: convert to warning
 
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
@@ -636,7 +633,7 @@ class UNetModel(nn.Module):
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
         if use_fp16:
-            logger.warning("use_fp16 was dropped and has no effect anymore.")
+            print("WARNING: use_fp16 was dropped and has no effect anymore.")
         # self.dtype = th.float16 if use_fp16 else th.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
@@ -667,7 +664,7 @@ class UNetModel(nn.Module):
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
             elif self.num_classes == "continuous":
-                logger.debug("setting up linear c_adm embedding layer")
+                print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
             elif self.num_classes == "timestep":
                 self.label_emb = checkpoint_wrapper_fn(
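The final hunk sits in the class-conditioning setup of UNetModel, where num_classes selects the label-embedding module: an integer yields an nn.Embedding lookup table, the string "continuous" a single nn.Linear. A self-contained sketch of that dispatch (the helper name and the 512-dim default are illustrative, not the repository's API):

```python
# Hypothetical helper illustrating the label-embedding dispatch from the hunk above.
import torch
import torch.nn as nn


def build_label_emb(num_classes, time_embed_dim: int = 512) -> nn.Module:
    if isinstance(num_classes, int):
        # class-conditional: one learned row per discrete class label
        return nn.Embedding(num_classes, time_embed_dim)
    elif num_classes == "continuous":
        # scalar conditioning value mapped through a single linear layer
        return nn.Linear(1, time_embed_dim)
    raise ValueError(f"unsupported num_classes: {num_classes!r}")


print(build_label_emb(1000)(torch.tensor([3])).shape)          # torch.Size([1, 512])
print(build_label_emb("continuous")(torch.randn(4, 1)).shape)  # torch.Size([4, 512])
```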