mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add support for accelerate in the pyo3 bindings. (#1167)
This commit is contained in:
@ -14,16 +14,18 @@ name = "candle"
|
|||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
|
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
pyo3-build-config = "0.19"
|
pyo3-build-config = "0.19"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
|
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||||
cuda = ["candle/cuda"]
|
cuda = ["candle/cuda"]
|
||||||
mkl = ["dep:intel-mkl-src","candle/mkl"]
|
mkl = ["dep:intel-mkl-src","candle/mkl"]
|
||||||
|
@ -11,6 +11,9 @@ use half::{bf16, f16};
|
|||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||||
|
|
||||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
import candle
|
import candle
|
||||||
|
|
||||||
|
print(f"mkl: {candle.utils.has_mkl()}")
|
||||||
|
print(f"accelerate: {candle.utils.has_accelerate()}")
|
||||||
|
print(f"num-threads: {candle.utils.get_num_threads()}")
|
||||||
|
print(f"cuda: {candle.utils.cuda_is_available()}")
|
||||||
|
|
||||||
t = candle.Tensor(42.0)
|
t = candle.Tensor(42.0)
|
||||||
print(t)
|
print(t)
|
||||||
print(t.shape, t.rank, t.device)
|
print(t.shape, t.rank, t.device)
|
||||||
|
Reference in New Issue
Block a user