mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Very simple pyo3 bindings for candle.
This commit is contained in:
46
candle-pyo3/src/lib.rs
Normal file
46
candle-pyo3/src/lib.rs
Normal file
@ -0,0 +1,46 @@
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{
|
||||
exceptions::PyValueError,
|
||||
};
|
||||
|
||||
use ::candle::{Tensor, Device::Cpu};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
||||
|
||||
#[pyclass(name="Tensor")]
|
||||
struct PyTensor(Tensor);
|
||||
|
||||
impl std::ops::Deref for PyTensor {
|
||||
type Target = Tensor;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyTensor {
|
||||
#[new]
|
||||
fn new(f: f32) -> PyResult<Self> {
|
||||
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
|
||||
let tensor = (&tensor.0 + f).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PyTensor>()?;
|
||||
m.add_function(wrap_pyfunction!(add, m)?)?;
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user