Very simple pyo3 bindings for candle.

This commit is contained in:
laurent
2023-07-01 20:36:44 +01:00
parent dd879f5b67
commit ebb0fedf14
6 changed files with 77 additions and 1 deletions

1
.gitignore vendored
View File

@ -18,3 +18,4 @@ Cargo.lock
perf.data
flamegraph.svg
*.so

View File

@ -1,6 +1,7 @@
[workspace]
members = [
"candle-core",
"candle-hub",
"candle-kernels",
"candle-hub",
"candle-pyo3",
]

19
candle-pyo3/Cargo.toml Normal file
View File

@ -0,0 +1,19 @@
[package]
name = "candle-pyo3"
version = "0.1.0"
edition = "2021"
description = "PyO3 bindings for the candle ML framework."
repository = "https://github.com/LaurentMazare/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
readme = "README.md"
[lib]
name = "candle"
crate-type = ["cdylib"]
[dependencies]
candle = { path = "../candle-core", default-features=false }
pyo3 = { version = "0.19.0", features = ["extension-module"] }

4
candle-pyo3/README.md Normal file
View File

@ -0,0 +1,4 @@
```
cargo build --release --package candle-pyo3 --no-default-features && cp -f target/release/libcandle.so candle.so
PYTHONPATH=. python3 candle-pyo3/test.py
```

46
candle-pyo3/src/lib.rs Normal file
View File

@ -0,0 +1,46 @@
use pyo3::prelude::*;
use pyo3::{
exceptions::PyValueError,
};
use ::candle::{Tensor, Device::Cpu};
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[pyclass(name="Tensor")]
struct PyTensor(Tensor);
impl std::ops::Deref for PyTensor {
type Target = Tensor;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[pymethods]
impl PyTensor {
#[new]
fn new(f: f32) -> PyResult<Self> {
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
}
fn __repr__(&self) -> String {
format!("{}", self.0)
}
}
#[pyfunction]
fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
let tensor = (&tensor.0 + f).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
m.add_function(wrap_pyfunction!(add, m)?)?;
Ok(())
}

5
candle-pyo3/test.py Normal file
View File

@ -0,0 +1,5 @@
import candle
t = candle.Tensor(42.0)
print(t)
print(candle.add(t, 3.14))