mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 22:34:22 +01:00
This reverts commit 6f6d3f8716.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
@@ -19,8 +18,6 @@ from ..util import (
|
||||
log_txt_as_img,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiffusionEngine(pl.LightningModule):
|
||||
def __init__(
|
||||
@@ -76,7 +73,7 @@ class DiffusionEngine(pl.LightningModule):
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
||||
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
self.disable_first_stage_autocast = disable_first_stage_autocast
|
||||
@@ -97,13 +94,13 @@ class DiffusionEngine(pl.LightningModule):
|
||||
raise NotImplementedError
|
||||
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
logger.info(
|
||||
print(
|
||||
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
||||
)
|
||||
if len(missing) > 0:
|
||||
logger.info(f"Missing Keys: {missing}")
|
||||
print(f"Missing Keys: {missing}")
|
||||
if len(unexpected) > 0:
|
||||
logger.info(f"Unexpected Keys: {unexpected}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def _init_first_stage(self, config):
|
||||
model = instantiate_from_config(config).eval()
|
||||
@@ -182,14 +179,14 @@ class DiffusionEngine(pl.LightningModule):
|
||||
self.model_ema.store(self.model.parameters())
|
||||
self.model_ema.copy_to(self.model)
|
||||
if context is not None:
|
||||
logger.info(f"{context}: Switched to EMA weights")
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.model.parameters())
|
||||
if context is not None:
|
||||
logger.info(f"{context}: Restored training weights")
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
@@ -205,7 +202,7 @@ class DiffusionEngine(pl.LightningModule):
|
||||
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
logger.debug("Setting up LambdaLR scheduler...")
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
||||
|
||||
Reference in New Issue
Block a user