From ebb0fedf145f0e6d2b332e884bf62009dd33d477 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 20:36:44 +0100 Subject: [PATCH 1/6] Very simple pyo3 bindings for candle. --- .gitignore | 1 + Cargo.toml | 3 ++- candle-pyo3/Cargo.toml | 19 +++++++++++++++++ candle-pyo3/README.md | 4 ++++ candle-pyo3/src/lib.rs | 46 ++++++++++++++++++++++++++++++++++++++++++ candle-pyo3/test.py | 5 +++++ 6 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 candle-pyo3/Cargo.toml create mode 100644 candle-pyo3/README.md create mode 100644 candle-pyo3/src/lib.rs create mode 100644 candle-pyo3/test.py diff --git a/.gitignore b/.gitignore index 33593c9b..2997bb61 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ Cargo.lock perf.data flamegraph.svg +*.so diff --git a/Cargo.toml b/Cargo.toml index 5413a592..7ecf17bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "candle-core", - "candle-hub", "candle-kernels", + "candle-hub", + "candle-pyo3", ] diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml new file mode 100644 index 00000000..75a52b4d --- /dev/null +++ b/candle-pyo3/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "candle-pyo3" +version = "0.1.0" +edition = "2021" + +description = "PyO3 bindings for the candle ML framework." +repository = "https://github.com/LaurentMazare/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT/Apache-2.0" +readme = "README.md" + +[lib] +name = "candle" +crate-type = ["cdylib"] + +[dependencies] +candle = { path = "../candle-core", default-features=false } +pyo3 = { version = "0.19.0", features = ["extension-module"] } diff --git a/candle-pyo3/README.md b/candle-pyo3/README.md new file mode 100644 index 00000000..100a74d2 --- /dev/null +++ b/candle-pyo3/README.md @@ -0,0 +1,4 @@ +``` +cargo build --release --package candle-pyo3 --no-default-features && cp -f target/release/libcandle.so candle.so +PYTHONPATH=. python3 candle-pyo3/test.py +``` diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs new file mode 100644 index 00000000..be689559 --- /dev/null +++ b/candle-pyo3/src/lib.rs @@ -0,0 +1,46 @@ +use pyo3::prelude::*; +use pyo3::{ + exceptions::PyValueError, +}; + +use ::candle::{Tensor, Device::Cpu}; + +pub fn wrap_err(err: ::candle::Error) -> PyErr { + PyErr::new::(format!("{err:?}")) +} + +#[pyclass(name="Tensor")] +struct PyTensor(Tensor); + +impl std::ops::Deref for PyTensor { + type Target = Tensor; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[pymethods] +impl PyTensor { + #[new] + fn new(f: f32) -> PyResult { + Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) + } + + fn __repr__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyfunction] +fn add(tensor: &PyTensor, f: f64) -> PyResult { + let tensor = (&tensor.0 + f).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pymodule] +fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_function(wrap_pyfunction!(add, m)?)?; + Ok(()) +} diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py new file mode 100644 index 00000000..b8a2818e --- /dev/null +++ b/candle-pyo3/test.py @@ -0,0 +1,5 @@ +import candle + +t = candle.Tensor(42.0) +print(t) +print(candle.add(t, 3.14)) From 52db2a6849659f0029a6fd136c47e945d8eef50f Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 20:37:28 +0100 Subject: [PATCH 2/6] Apply rustfmt. --- candle-pyo3/src/lib.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index be689559..be9e427d 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,15 +1,13 @@ +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::{ - exceptions::PyValueError, -}; -use ::candle::{Tensor, Device::Cpu}; +use ::candle::{Device::Cpu, Tensor}; pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) } -#[pyclass(name="Tensor")] +#[pyclass(name = "Tensor")] struct PyTensor(Tensor); impl std::ops::Deref for PyTensor { From 42d1a52d01f5f10e3a04257cb4612225f08e1321 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 20:55:15 +0100 Subject: [PATCH 3/6] Add two methods. --- candle-pyo3/src/lib.rs | 14 ++++++++++++++ candle-pyo3/test.py | 1 + 2 files changed, 15 insertions(+) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index be9e427d..e1ce7f97 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -25,9 +25,23 @@ impl PyTensor { Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) } + #[getter] + fn shape(&self) -> Vec { + self.0.dims().to_vec() + } + + #[getter] + fn rank(&self) -> usize { + self.0.rank() + } + fn __repr__(&self) -> String { format!("{}", self.0) } + + fn __str__(&self) -> String { + self.__repr__() + } } #[pyfunction] diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index b8a2818e..21242b44 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -2,4 +2,5 @@ import candle t = candle.Tensor(42.0) print(t) +print("shape", t.shape, t.rank) print(candle.add(t, 3.14)) From fbbde5b02cc4a711aed609f887d1705d67b2fd20 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 21:27:35 +0100 Subject: [PATCH 4/6] Add some binary operators. --- candle-pyo3/src/lib.rs | 18 +++++++++++++++++- candle-pyo3/test.py | 2 +- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index e1ce7f97..35de86c8 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,4 +1,4 @@ -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use ::candle::{Device::Cpu, Tensor}; @@ -7,6 +7,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) } +#[derive(Clone)] #[pyclass(name = "Tensor")] struct PyTensor(Tensor); @@ -42,6 +43,21 @@ impl PyTensor { fn __str__(&self) -> String { self.__repr__() } + + fn __add__(&self, rhs: &PyAny) -> PyResult { + let tensor = if let Ok(rhs) = rhs.extract::() { + (&self.0 + &rhs.0).map_err(wrap_err)? + } else if let Ok(rhs) = rhs.extract::() { + (&self.0 + rhs).map_err(wrap_err)? + } else { + Err(PyTypeError::new_err("unsupported for add"))? + }; + Ok(Self(tensor)) + } + + fn __radd__(&self, rhs: &PyAny) -> PyResult { + self.__add__(rhs) + } } #[pyfunction] diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 21242b44..0db3fab9 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -3,4 +3,4 @@ import candle t = candle.Tensor(42.0) print(t) print("shape", t.shape, t.rank) -print(candle.add(t, 3.14)) +print(t + t) From 86df4ad79c156b306a8c0402c6475c772ef0e366 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 21:34:38 +0100 Subject: [PATCH 5/6] Get shape to return a tuple. --- candle-pyo3/src/lib.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 35de86c8..ab280a63 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,3 +1,4 @@ +use pyo3::types::PyTuple; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; @@ -27,8 +28,8 @@ impl PyTensor { } #[getter] - fn shape(&self) -> Vec { - self.0.dims().to_vec() + fn shape(&self, py: Python<'_>) -> PyObject { + PyTuple::new(py, self.0.dims()).to_object(py) } #[getter] From 2370b1675d1ded0e5bb51056780902d87f53ce28 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 22:15:58 +0100 Subject: [PATCH 6/6] More pyo3. --- candle-pyo3/Cargo.toml | 1 + candle-pyo3/src/lib.rs | 54 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 75a52b4d..fd2890f6 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -17,3 +17,4 @@ crate-type = ["cdylib"] [dependencies] candle = { path = "../candle-core", default-features=false } pyo3 = { version = "0.19.0", features = ["extension-module"] } +half = { version = "2.3.1", features = ["num-traits"] } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ab280a63..ddf7b554 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,8 +1,10 @@ -use pyo3::types::PyTuple; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; +use pyo3::types::{PyString, PyTuple}; -use ::candle::{Device::Cpu, Tensor}; +use half::{bf16, f16}; + +use ::candle::{DType, Device::Cpu, Tensor}; pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) @@ -23,15 +25,48 @@ impl std::ops::Deref for PyTensor { #[pymethods] impl PyTensor { #[new] + // TODO: Handle arbitrary input dtype and shape. fn new(f: f32) -> PyResult { Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) } + fn values(&self, py: Python<'_>) -> PyResult { + // TODO: Handle arbitrary shapes. + let v = match self.0.dtype() { + // TODO: Use the map bits to avoid enumerating the types. + DType::U8 => self.to_scalar::().map_err(wrap_err)?.to_object(py), + DType::U32 => self.to_scalar::().map_err(wrap_err)?.to_object(py), + DType::F32 => self.to_scalar::().map_err(wrap_err)?.to_object(py), + DType::F64 => self.to_scalar::().map_err(wrap_err)?.to_object(py), + DType::BF16 => self + .to_scalar::() + .map_err(wrap_err)? + .to_f32() + .to_object(py), + DType::F16 => self + .to_scalar::() + .map_err(wrap_err)? + .to_f32() + .to_object(py), + }; + Ok(v) + } + #[getter] fn shape(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.dims()).to_object(py) } + #[getter] + fn stride(&self, py: Python<'_>) -> PyObject { + PyTuple::new(py, self.0.stride()).to_object(py) + } + + #[getter] + fn dtype(&self, py: Python<'_>) -> PyObject { + PyString::new(py, self.0.dtype().as_str()).to_object(py) + } + #[getter] fn rank(&self) -> usize { self.0.rank() @@ -59,6 +94,21 @@ impl PyTensor { fn __radd__(&self, rhs: &PyAny) -> PyResult { self.__add__(rhs) } + + fn __mul__(&self, rhs: &PyAny) -> PyResult { + let tensor = if let Ok(rhs) = rhs.extract::() { + (&self.0 * &rhs.0).map_err(wrap_err)? + } else if let Ok(rhs) = rhs.extract::() { + (&self.0 * rhs).map_err(wrap_err)? + } else { + Err(PyTypeError::new_err("unsupported for mul"))? + }; + Ok(Self(tensor)) + } + + fn __rmul__(&self, rhs: &PyAny) -> PyResult { + self.__mul__(rhs) + } } #[pyfunction]