mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
PyO3: Add mkl
support (#1159)
* Add `mkl` support * Set `mkl` path on linux
This commit is contained in:
@ -18,6 +18,7 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
half = { workspace = true }
|
||||
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = "0.19"
|
||||
@ -25,3 +26,4 @@ pyo3-build-config = "0.19"
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["candle/cuda"]
|
||||
mkl = ["dep:intel-mkl-src","candle/mkl"]
|
||||
|
@ -3,11 +3,14 @@ import logging
|
||||
try:
|
||||
from .candle import *
|
||||
except ImportError as e:
|
||||
# If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
|
||||
logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
|
||||
# If we are in development mode, or we did not bundle the DLLs, we try to locate them here
|
||||
# PyO3 won't give us any information about what DLLs are missing, so we can only try to load the DLLs and re-import the module
|
||||
logging.warning("DLLs were not bundled with this package. Trying to locate them...")
|
||||
import os
|
||||
import platform
|
||||
|
||||
def locate_cuda_dlls():
|
||||
logging.warning("Locating CUDA DLLs...")
|
||||
# Try to locate CUDA_PATH environment variable
|
||||
cuda_path = os.environ.get("CUDA_PATH", None)
|
||||
if cuda_path:
|
||||
@ -19,11 +22,32 @@ except ImportError as e:
|
||||
|
||||
logging.warning(f"Adding {cuda_path} to DLL search path...")
|
||||
os.add_dll_directory(cuda_path)
|
||||
else:
|
||||
logging.warning("CUDA_PATH environment variable not found!")
|
||||
|
||||
def locate_mkl_dlls():
|
||||
# Try to locate ONEAPI_ROOT environment variable
|
||||
oneapi_root = os.environ.get("ONEAPI_ROOT", None)
|
||||
if oneapi_root:
|
||||
if platform.system() == "Windows":
|
||||
mkl_path = os.path.join(
|
||||
oneapi_root, "compiler", "latest", "windows", "redist", "intel64_win", "compiler"
|
||||
)
|
||||
else:
|
||||
mkl_path = os.path.join(oneapi_root, "mkl", "latest", "lib", "intel64")
|
||||
|
||||
logging.warning(f"Adding {mkl_path} to DLL search path...")
|
||||
os.add_dll_directory(mkl_path)
|
||||
else:
|
||||
logging.warning("ONEAPI_ROOT environment variable not found!")
|
||||
|
||||
locate_cuda_dlls()
|
||||
locate_mkl_dlls()
|
||||
|
||||
try:
|
||||
from .candle import *
|
||||
except ImportError as inner_e:
|
||||
raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")
|
||||
raise ImportError("Could not locate DLLs. Please check the documentation for more information.")
|
||||
|
||||
__doc__ = candle.__doc__
|
||||
if hasattr(candle, "__all__"):
|
||||
|
@ -8,6 +8,9 @@ use std::sync::Arc;
|
||||
|
||||
use half::{bf16, f16};
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
|
Reference in New Issue
Block a user