mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Very simple pyo3 bindings for candle.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -18,3 +18,4 @@ Cargo.lock
|
|||||||
|
|
||||||
perf.data
|
perf.data
|
||||||
flamegraph.svg
|
flamegraph.svg
|
||||||
|
*.so
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"candle-core",
|
"candle-core",
|
||||||
"candle-hub",
|
|
||||||
"candle-kernels",
|
"candle-kernels",
|
||||||
|
"candle-hub",
|
||||||
|
"candle-pyo3",
|
||||||
]
|
]
|
||||||
|
19
candle-pyo3/Cargo.toml
Normal file
19
candle-pyo3/Cargo.toml
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
[package]
|
||||||
|
name = "candle-pyo3"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
description = "PyO3 bindings for the candle ML framework."
|
||||||
|
repository = "https://github.com/LaurentMazare/candle"
|
||||||
|
keywords = ["blas", "tensor", "machine-learning"]
|
||||||
|
categories = ["science"]
|
||||||
|
license = "MIT/Apache-2.0"
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "candle"
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
candle = { path = "../candle-core", default-features=false }
|
||||||
|
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
4
candle-pyo3/README.md
Normal file
4
candle-pyo3/README.md
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
```
|
||||||
|
cargo build --release --package candle-pyo3 --no-default-features && cp -f target/release/libcandle.so candle.so
|
||||||
|
PYTHONPATH=. python3 candle-pyo3/test.py
|
||||||
|
```
|
46
candle-pyo3/src/lib.rs
Normal file
46
candle-pyo3/src/lib.rs
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
use pyo3::prelude::*;
|
||||||
|
use pyo3::{
|
||||||
|
exceptions::PyValueError,
|
||||||
|
};
|
||||||
|
|
||||||
|
use ::candle::{Tensor, Device::Cpu};
|
||||||
|
|
||||||
|
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||||
|
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyclass(name="Tensor")]
|
||||||
|
struct PyTensor(Tensor);
|
||||||
|
|
||||||
|
impl std::ops::Deref for PyTensor {
|
||||||
|
type Target = Tensor;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl PyTensor {
|
||||||
|
#[new]
|
||||||
|
fn new(f: f32) -> PyResult<Self> {
|
||||||
|
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __repr__(&self) -> String {
|
||||||
|
format!("{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
|
||||||
|
let tensor = (&tensor.0 + f).map_err(wrap_err)?;
|
||||||
|
Ok(PyTensor(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymodule]
|
||||||
|
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
|
m.add_class::<PyTensor>()?;
|
||||||
|
m.add_function(wrap_pyfunction!(add, m)?)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
5
candle-pyo3/test.py
Normal file
5
candle-pyo3/test.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import candle
|
||||||
|
|
||||||
|
t = candle.Tensor(42.0)
|
||||||
|
print(t)
|
||||||
|
print(candle.add(t, 3.14))
|
Reference in New Issue
Block a user