Fix loading safetensors with load_model_from_config

Previously, the `sd` from the safetensors if branch was not used at all, and `pl_sd` would have not been assigned.
This commit is contained in:
Aarni Koskela
2023-07-17 09:56:35 +03:00
parent 5c10deee76
commit 48904a692d

View File

@@ -212,7 +212,6 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
raise NotImplementedError raise NotImplementedError
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
sd = pl_sd["state_dict"]
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)