diff --git a/.gitignore b/.gitignore index 33593c9b..2997bb61 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ Cargo.lock perf.data flamegraph.svg +*.so diff --git a/Cargo.toml b/Cargo.toml index 5413a592..7ecf17bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "candle-core", - "candle-hub", "candle-kernels", + "candle-hub", + "candle-pyo3", ] diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml new file mode 100644 index 00000000..75a52b4d --- /dev/null +++ b/candle-pyo3/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "candle-pyo3" +version = "0.1.0" +edition = "2021" + +description = "PyO3 bindings for the candle ML framework." +repository = "https://github.com/LaurentMazare/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT/Apache-2.0" +readme = "README.md" + +[lib] +name = "candle" +crate-type = ["cdylib"] + +[dependencies] +candle = { path = "../candle-core", default-features=false } +pyo3 = { version = "0.19.0", features = ["extension-module"] } diff --git a/candle-pyo3/README.md b/candle-pyo3/README.md new file mode 100644 index 00000000..100a74d2 --- /dev/null +++ b/candle-pyo3/README.md @@ -0,0 +1,4 @@ +``` +cargo build --release --package candle-pyo3 --no-default-features && cp -f target/release/libcandle.so candle.so +PYTHONPATH=. python3 candle-pyo3/test.py +``` diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs new file mode 100644 index 00000000..be689559 --- /dev/null +++ b/candle-pyo3/src/lib.rs @@ -0,0 +1,46 @@ +use pyo3::prelude::*; +use pyo3::{ + exceptions::PyValueError, +}; + +use ::candle::{Tensor, Device::Cpu}; + +pub fn wrap_err(err: ::candle::Error) -> PyErr { + PyErr::new::(format!("{err:?}")) +} + +#[pyclass(name="Tensor")] +struct PyTensor(Tensor); + +impl std::ops::Deref for PyTensor { + type Target = Tensor; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[pymethods] +impl PyTensor { + #[new] + fn new(f: f32) -> PyResult { + Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) + } + + fn __repr__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyfunction] +fn add(tensor: &PyTensor, f: f64) -> PyResult { + let tensor = (&tensor.0 + f).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pymodule] +fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_function(wrap_pyfunction!(add, m)?)?; + Ok(()) +} diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py new file mode 100644 index 00000000..b8a2818e --- /dev/null +++ b/candle-pyo3/test.py @@ -0,0 +1,5 @@ +import candle + +t = candle.Tensor(42.0) +print(t) +print(candle.add(t, 3.14))