mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-23 06:34:27 +01:00
Test model device manager and fix bugs
This commit is contained in:
51
tests/inference/test_modelmanager.py
Normal file
51
tests/inference/test_modelmanager.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import numpy
|
||||
from PIL import Image
|
||||
import pytest
|
||||
from pytest import fixture
|
||||
import torch
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from sgm.inference.api import (
|
||||
model_specs,
|
||||
SamplingParams,
|
||||
SamplingPipeline,
|
||||
Sampler,
|
||||
ModelArchitecture,
|
||||
)
|
||||
import sgm.inference.helpers as helpers
|
||||
|
||||
def get_torch_device(model: torch.nn.Module) -> torch.device:
|
||||
param = next(model.parameters(), None)
|
||||
if param is not None:
|
||||
return param.device
|
||||
else:
|
||||
buf = next(model.buffers(), None)
|
||||
if buf is not None:
|
||||
return buf.device
|
||||
else:
|
||||
raise TypeError("Could not determine device of input model")
|
||||
|
||||
|
||||
@pytest.mark.inference
|
||||
def test_default_loading():
|
||||
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
|
||||
assert get_torch_device(pipeline.model.model).type == "cuda"
|
||||
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
|
||||
with pipeline.device_manager.use(pipeline.model.model):
|
||||
assert get_torch_device(pipeline.model.model).type == "cuda"
|
||||
assert get_torch_device(pipeline.model.model).type == "cuda"
|
||||
with pipeline.device_manager.use(pipeline.model.conditioner):
|
||||
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
|
||||
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
|
||||
|
||||
@pytest.mark.inference
|
||||
def test_model_swapping():
|
||||
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu"))
|
||||
assert get_torch_device(pipeline.model.model).type == "cpu"
|
||||
assert get_torch_device(pipeline.model.conditioner).type == "cpu"
|
||||
with pipeline.device_manager.use(pipeline.model.model):
|
||||
assert get_torch_device(pipeline.model.model).type == "cuda"
|
||||
assert get_torch_device(pipeline.model.model).type == "cpu"
|
||||
with pipeline.device_manager.use(pipeline.model.conditioner):
|
||||
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
|
||||
assert get_torch_device(pipeline.model.conditioner).type == "cpu"
|
||||
Reference in New Issue
Block a user