diff --git a/sdkit/models/model_loader/stable_diffusion/__init__.py b/sdkit/models/model_loader/stable_diffusion/__init__.py index 63fd9c7..01eb0c4 100644 --- a/sdkit/models/model_loader/stable_diffusion/__init__.py +++ b/sdkit/models/model_loader/stable_diffusion/__init__.py @@ -125,7 +125,7 @@ def load_diffusers_model(context: Context, model_path, config_file_path, convert import platform from sdkit.generate.sampler import diffusers_samplers - from sdkit.utils import gc, has_amd_gpu + from sdkit.utils import gc, get_directml_device_id from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt from . import diffusers_bugfixes @@ -159,7 +159,7 @@ def load_diffusers_model(context: Context, model_path, config_file_path, convert unet_trt_path = model_component + ".unet.trt" unet_onnx_path = model_component + ".unet.onnx" - use_directml = platform.system() == "Windows" and has_amd_gpu() + use_directml = (get_directml_device_id() != None) try: from importlib.metadata import version diff --git a/sdkit/models/model_loader/stable_diffusion/accelerators.py b/sdkit/models/model_loader/stable_diffusion/accelerators.py index 2c36ada..32a8eac 100644 --- a/sdkit/models/model_loader/stable_diffusion/accelerators.py +++ b/sdkit/models/model_loader/stable_diffusion/accelerators.py @@ -2,7 +2,7 @@ from dataclasses import dataclass import numpy as np -from sdkit.utils import log +from sdkit.utils import log, get_directml_device_id """ Current issues: @@ -53,15 +53,7 @@ def __init__(self, onnx_path): # sess_options.add_free_dimension_override_by_name("encoder_hidden_states_batch", batch_size * 2) # sess_options.add_free_dimension_override_by_name("encoder_hidden_states_sequence", 77) - import wmi - - w = wmi.WMI() - device_id = 0 - for i, controller in enumerate(w.Win32_VideoController()): - device_name = controller.wmi_property("Name").value - if "AMD" in device_name and "Radeon" in device_name: - device_id = i - break + device_id = get_directml_device_id() log.info(f"Using DirectML device_id: {device_id}") sess = ort.InferenceSession( diff --git a/sdkit/utils/__init__.py b/sdkit/utils/__init__.py index 64dc99b..fb3fdd1 100644 --- a/sdkit/utils/__init__.py +++ b/sdkit/utils/__init__.py @@ -39,4 +39,7 @@ convert_pipeline_unet_to_onnx, convert_pipeline_unet_to_tensorrt, ) -from .device_utils import has_amd_gpu +from .device_utils import ( + has_amd_gpu, + get_directml_device_id +) diff --git a/sdkit/utils/device_utils.py b/sdkit/utils/device_utils.py index 0414dcf..2e559b2 100644 --- a/sdkit/utils/device_utils.py +++ b/sdkit/utils/device_utils.py @@ -1,5 +1,6 @@ import platform import subprocess +import wmi def has_amd_gpu(): @@ -18,3 +19,20 @@ def has_amd_gpu(): return False return False + +def get_directml_device_id(): + os_name = platform.system() + + if os_name != "Windows": + return None + + w = wmi.WMI() + device_id = None + for i, controller in enumerate(w.Win32_VideoController()): + device_name = controller.wmi_property("Name").value + if ("AMD" in device_name and "Radeon" in device_name) or ("Intel" in device_name and "Arc" in device_name): + device_id = i + break + + return device_id +