Files
generative-models/tests/inference/test_modelmanager.py
2023-08-12 07:15:36 +00:00

51 lines
2.0 KiB
Python

import numpy
from PIL import Image
import pytest
from pytest import fixture
import torch
from typing import Tuple, Optional
from sgm.inference.api import (
model_specs,
SamplingParams,
SamplingPipeline,
Sampler,
ModelArchitecture,
)
import sgm.inference.helpers as helpers
def get_torch_device(model: torch.nn.Module) -> torch.device:
    """Return the device of ``model``, judged from a single tensor.

    The first tensor yielded by ``model.parameters()`` is used.  If the module
    has no parameters, the first tensor yielded by ``model.buffers()`` is used
    instead.  Only one tensor is inspected, so a module whose tensors are
    split across devices reports whichever device that tensor is on.

    Args:
        model: The module to inspect.

    Returns:
        The ``torch.device`` of the inspected tensor.

    Raises:
        TypeError: If ``model`` has neither parameters nor buffers.
    """
    # Parameters take priority; buffers are only consulted as a fallback.
    for tensor_source in (model.parameters, model.buffers):
        first_tensor = next(tensor_source(), None)
        if first_tensor is not None:
            return first_tensor.device
    raise TypeError("Could not determine device of input model")
@pytest.mark.inference
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="test expects a CUDA device"
)
def test_default_loading():
    """Default device management keeps SD 2.1 components on CUDA.

    With no explicit device manager, both the diffusion model and the
    conditioner are expected on CUDA in three situations: right after
    construction, while wrapped in ``device_manager.use(...)``, and after
    that context exits.  The test is skipped on hosts without CUDA, because
    every expectation below is hard-coded to ``"cuda"``.
    """
    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
    assert get_torch_device(pipeline.model.model).type == "cuda"
    assert get_torch_device(pipeline.model.conditioner).type == "cuda"

    with pipeline.device_manager.use(pipeline.model.model):
        assert get_torch_device(pipeline.model.model).type == "cuda"
    # Leaving the context must not move the model off the GPU.
    assert get_torch_device(pipeline.model.model).type == "cuda"

    with pipeline.device_manager.use(pipeline.model.conditioner):
        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
    # Likewise, the conditioner stays on the GPU after the context exits.
    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
@pytest.mark.inference
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="test expects a CUDA device"
)
def test_model_swapping():
    """Swapping device manager moves components to CUDA only while in use.

    With ``CudaModelManager(device="cuda", swap_device="cpu")``, the model and
    the conditioner are expected to:

    - start on the CPU;
    - move to CUDA inside ``device_manager.use(...)``;
    - return to the CPU once that context exits.

    The test is skipped on hosts without CUDA, because the in-use
    expectations are hard-coded to ``"cuda"``.
    """
    pipeline = SamplingPipeline(
        model_id=ModelArchitecture.SD_2_1,
        device=helpers.CudaModelManager(device="cuda", swap_device="cpu"),
    )
    # Nothing is in use yet, so both components should sit on the swap device.
    assert get_torch_device(pipeline.model.model).type == "cpu"
    assert get_torch_device(pipeline.model.conditioner).type == "cpu"

    with pipeline.device_manager.use(pipeline.model.model):
        assert get_torch_device(pipeline.model.model).type == "cuda"
    assert get_torch_device(pipeline.model.model).type == "cpu"

    with pipeline.device_manager.use(pipeline.model.conditioner):
        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
    assert get_torch_device(pipeline.model.conditioner).type == "cpu"