mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-20 06:44:22 +01:00
This reverts commit 6f6d3f8716.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
@@ -33,8 +32,6 @@ from ...util import (
|
||||
instantiate_from_config,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractEmbModel(nn.Module):
|
||||
def __init__(self):
|
||||
@@ -99,7 +96,7 @@ class GeneralConditioner(nn.Module):
|
||||
for param in embedder.parameters():
|
||||
param.requires_grad = False
|
||||
embedder.eval()
|
||||
logger.debug(
|
||||
print(
|
||||
f"Initialized embedder #{n}: {embedder.__class__.__name__} "
|
||||
f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
|
||||
)
|
||||
@@ -730,7 +727,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
|
||||
)
|
||||
if tokens is not None:
|
||||
tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
|
||||
logger.warning(
|
||||
print(
|
||||
f"You are running very experimental token-concat in {self.__class__.__name__}. "
|
||||
f"Check what you are doing, and then remove this message."
|
||||
)
|
||||
@@ -756,7 +753,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
|
||||
clip_version, device, max_length=clip_max_length
|
||||
)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||
logger.debug(
|
||||
print(
|
||||
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
|
||||
)
|
||||
@@ -798,7 +795,7 @@ class SpatialRescaler(nn.Module):
|
||||
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
|
||||
self.remap_output = out_channels is not None or remap_output
|
||||
if self.remap_output:
|
||||
logger.debug(
|
||||
print(
|
||||
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
|
||||
)
|
||||
self.channel_mapper = nn.Conv2d(
|
||||
|
||||
Reference in New Issue
Block a user