Mirror of https://github.com/Stability-AI/generative-models.git, synced 2025-12-20 14:54:21 +01:00
This reverts commit 6f6d3f8716.
@@ -1,4 +1,3 @@
-import logging
 import math
 from abc import abstractmethod
 from functools import partial
@@ -22,8 +21,6 @@ from ...modules.diffusionmodules.util import (
 )
 from ...util import default, exists
 
-logger = logging.getLogger(__name__)
-
 
 # dummy replace
 def convert_module_to_f16(x):
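
The two hunks above undo the module-level logger that commit 6f6d3f8716 introduced: the import logging line and the logger = logging.getLogger(__name__) assignment are removed, and the print calls restored below take their place. For reference, a minimal sketch of the per-module logger pattern being reverted (a generic illustration with a hypothetical helper, not code from this file):

    import logging

    # one logger per module, named after the module's import path
    logger = logging.getLogger(__name__)


    def build_layer(dims: int) -> None:
        # hypothetical helper; debug records stay silent unless the
        # application configures logging
        logger.debug("Building a layer with %d dims.", dims)
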
@@ -180,13 +177,13 @@ class Downsample(nn.Module):
         self.dims = dims
         stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
         if use_conv:
-            logger.debug(
-                f"Building a Downsample layer with {dims} dims.\n"
+            print(f"Building a Downsample layer with {dims} dims.")
+            print(
                 f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
                 f"kernel-size: 3, stride: {stride}, padding: {padding}"
             )
             if dims == 3:
-                logger.debug(f" --> Downsampling third axis (time): {third_down}")
+                print(f" --> Downsampling third axis (time): {third_down}")
             self.op = conv_nd(
                 dims,
                 self.channels,
@@ -273,7 +270,7 @@ class ResBlock(TimestepBlock):
             2 * self.out_channels if use_scale_shift_norm else self.out_channels
         )
         if self.skip_t_emb:
-            logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
+            print(f"Skipping timestep embedding in {self.__class__.__name__}")
             assert not self.use_scale_shift_norm
             self.emb_layers = None
             self.exchange_temb_dims = False
@@ -622,12 +619,12 @@ class UNetModel(nn.Module):
                     range(len(num_attention_blocks)),
                 )
             )
-            logger.warning(
+            print(
                 f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
-            )
+            )  # todo: convert to warning
 
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
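
The restored closing parenthesis carries the original "# todo: convert to warning" comment. One way such a todo could be addressed, shown only as a hedged illustration with a hypothetical helper (not code from the repository), is Python's warnings module:

    import warnings


    def warn_attention_config(num_attention_blocks, attention_resolutions):
        # hypothetical helper; names mirror the UNetModel constructor arguments above
        warnings.warn(
            f"num_attention_blocks={num_attention_blocks} has LESS priority than "
            f"attention_resolutions={attention_resolutions}; attention is only set "
            f"where 2**i is in attention_resolutions.",
            UserWarning,
        )
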
@@ -636,7 +633,7 @@ class UNetModel(nn.Module):
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
         if use_fp16:
-            logger.warning("use_fp16 was dropped and has no effect anymore.")
+            print("WARNING: use_fp16 was dropped and has no effect anymore.")
         # self.dtype = th.float16 if use_fp16 else th.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
@@ -667,7 +664,7 @@ class UNetModel(nn.Module):
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
             elif self.num_classes == "continuous":
-                logger.debug("setting up linear c_adm embedding layer")
+                print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
             elif self.num_classes == "timestep":
                 self.label_emb = checkpoint_wrapper_fn(
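
Net effect of the revert: the restored print statements always write to stdout, whereas the removed logger.debug calls only produce output once the host application configures logging (logger.warning and above still reach stderr through logging's last-resort handler). A minimal sketch of such caller-side configuration, not part of this commit:

    import logging

    # Attach a stream handler to the root logger and lower its level so that
    # logger.debug(...) records from library modules become visible.
    logging.basicConfig(level=logging.DEBUG)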