From fa58c7643ded869a345b8df538dc38d21684ac88 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 2 Jul 2023 06:58:10 +0100
Subject: [PATCH 1/7] Add a trait to avoid repeating the dtype matching.

---
 candle-pyo3/src/lib.rs | 67 +++++++++++++++++++++++++++++-------------
 1 file changed, 47 insertions(+), 20 deletions(-)

diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index ddf7b554..1d3e4efd 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -4,7 +4,7 @@ use pyo3::types::{PyString, PyTuple};
 use half::{bf16, f16};
 
-use ::candle::{DType, Device::Cpu, Tensor};
+use ::candle::{DType, Device::Cpu, Tensor, WithDType};
 
 pub fn wrap_err(err: ::candle::Error) -> PyErr {
     PyErr::new::<PyValueError, _>(format!("{err:?}"))
 }
@@ -22,6 +22,43 @@ impl std::ops::Deref for PyTensor {
     }
 }
 
+trait PyDType: WithDType {
+    fn to_py(&self, py: Python<'_>) -> PyObject;
+}
+
+macro_rules! pydtype {
+    ($ty:ty, $conv:expr) => {
+        impl PyDType for $ty {
+            fn to_py(&self, py: Python<'_>) -> PyObject {
+                $conv(*self).to_object(py)
+            }
+        }
+    };
+}
+pydtype!(u8, |v| v);
+pydtype!(u32, |v| v);
+pydtype!(f16, f32::from);
+pydtype!(bf16, f32::from);
+pydtype!(f32, |v| v);
+pydtype!(f64, |v| v);
+
+// TODO: Something similar to this should probably be a part of candle core.
+trait MapDType {
+    type Output;
+    fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>;
+
+    fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
+        match t.dtype() {
+            DType::U8 => self.f::<u8>(t),
+            DType::U32 => self.f::<u32>(t),
+            DType::BF16 => self.f::<bf16>(t),
+            DType::F16 => self.f::<f16>(t),
+            DType::F32 => self.f::<f32>(t),
+            DType::F64 => self.f::<f64>(t),
+        }
+    }
+}
+
 #[pymethods]
 impl PyTensor {
     #[new]
@@ -30,26 +67,16 @@ impl PyTensor {
     // TODO: Handle arbitrary input dtype and shape.
     fn new(f: f32) -> PyResult<Self> {
         Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
     }
 
-    fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
+    fn scalar(&self, py: Python<'_>) -> PyResult<PyObject> {
+        struct M<'a>(Python<'a>);
+        impl<'a> MapDType for M<'a> {
+            type Output = PyObject;
+            fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
+                Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
+            }
+        }
         // TODO: Handle arbitrary shapes.
-        let v = match self.0.dtype() {
-            // TODO: Use the map bits to avoid enumerating the types.
-            DType::U8 => self.to_scalar::<u8>().map_err(wrap_err)?.to_object(py),
-            DType::U32 => self.to_scalar::<u32>().map_err(wrap_err)?.to_object(py),
-            DType::F32 => self.to_scalar::<f32>().map_err(wrap_err)?.to_object(py),
-            DType::F64 => self.to_scalar::<f64>().map_err(wrap_err)?.to_object(py),
-            DType::BF16 => self
-                .to_scalar::<bf16>()
-                .map_err(wrap_err)?
-                .to_f32()
-                .to_object(py),
-            DType::F16 => self
-                .to_scalar::<f16>()
-                .map_err(wrap_err)?
-                .to_f32()
-                .to_object(py),
-        };
-        Ok(v)
+        M(py).map(self)
     }
 
     #[getter]
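The point of the PyDType/MapDType pair above is that the dtype match is written exactly once: `map` inspects the runtime `DType` and dispatches to the generic `f::<T>` with the matching Rust type, while `PyDType::to_py` handles the per-type conversion (f16/bf16 are widened to f32, since Python has no half-precision float). From the Python side, a minimal sketch of the expected behaviour at this point in the series, assuming the extension has been built (e.g. with maturin) and is importable as `candle`:

    import candle

    t = candle.Tensor(42.0)  # the constructor still only accepts a single f32
    print(t.scalar())        # 42.0, converted through PyDType::to_py
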
From c62cb73a7f7487ee76ba0d93c1cde7706b1025b3 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 2 Jul 2023 07:07:22 +0100
Subject: [PATCH 2/7] Support higher order shapes for conversions.

---
 candle-pyo3/src/lib.rs | 32 +++++++++++++++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 1d3e4efd..4328ac01 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -72,7 +72,37 @@ impl PyTensor {
         impl<'a> MapDType for M<'a> {
             type Output = PyObject;
             fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
-                Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
+                match t.rank() {
+                    0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
+                    1 => {
+                        let v = t.to_vec1::<T>().map_err(wrap_err)?;
+                        let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
+                        Ok(v.to_object(self.0))
+                    }
+                    2 => {
+                        let v = t.to_vec2::<T>().map_err(wrap_err)?;
+                        let v = v
+                            .iter()
+                            .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
+                            .collect::<Vec<Vec<_>>>();
+                        Ok(v.to_object(self.0))
+                    }
+                    3 => {
+                        let v = t.to_vec3::<T>().map_err(wrap_err)?;
+                        let v = v
+                            .iter()
+                            .map(|v| {
+                                v.iter()
+                                    .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
+                                    .collect()
+                            })
+                            .collect::<Vec<Vec<Vec<_>>>>();
+                        Ok(v.to_object(self.0))
+                    }
+                    n => Err(PyTypeError::new_err(format!(
+                        "TODO: conversion to PyObject is not handled for rank {n}"
+                    )))?,
+                }
             }
         }
         // TODO: Handle arbitrary shapes.

From 4a28dcf828ba8393a8167218a0df18fa4ad1f37c Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 2 Jul 2023 07:08:11 +0100
Subject: [PATCH 3/7] Rename the method.

---
 candle-pyo3/src/lib.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 4328ac01..29503c8f 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -67,7 +67,8 @@ impl PyTensor {
         Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
     }
 
-    fn scalar(&self, py: Python<'_>) -> PyResult<PyObject> {
+    /// Gets the tensor data as a Python value/array/array of arrays/...
+    fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
         struct M<'a>(Python<'a>);
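After these two patches, `values()` mirrors the tensor's rank in the returned Python object: rank 0 yields a plain value, rank 1 a list, rank 2 a list of lists, rank 3 a list of lists of lists, and rank 4 or above still raises a TypeError. A rough sketch of the intent (the higher-rank cases only become constructible from Python in the next patch, which teaches the constructor to accept lists):

    import candle

    print(candle.Tensor(3.5).values())  # 3.5
    # Once list construction lands in the next patch, e.g.:
    # candle.Tensor([1.0, 2.0, 3.0, 4.0]).reshape([2, 2]).values()
    # is expected to return [[1.0, 2.0], [3.0, 4.0]]
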
From dfe197f791b9462c83d7ec9cc141886c868628a7 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 2 Jul 2023 07:19:46 +0100
Subject: [PATCH 4/7] Handle more input types to create tensors.

---
 candle-pyo3/src/lib.rs | 20 ++++++++++++++++++--
 candle-pyo3/test.py    |  5 +++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 29503c8f..c85b41f0 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -63,8 +63,19 @@ trait MapDType {
 impl PyTensor {
     #[new]
     // TODO: Handle arbitrary input dtype and shape.
-    fn new(f: f32) -> PyResult<Self> {
-        Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
+    fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
+        let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
+            Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<f32>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
+            Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
+        } else {
+            Err(PyTypeError::new_err("incorrect type for tensor"))?
+        };
+        Ok(Self(tensor))
     }
 
     /// Gets the tensor data as a Python value/array/array of arrays/...
@@ -167,6 +178,11 @@ impl PyTensor {
     fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
         self.__mul__(rhs)
    }
+
+    // TODO: Add a PyShape type?
+    fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
+        Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
+    }
 }

diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index 0db3fab9..aad5e8ae 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -4,3 +4,8 @@ t = candle.Tensor(42.0)
 print(t)
 print("shape", t.shape, t.rank)
 print(t + t)
+
+t = candle.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
+print(t)
+print(t+t)
+print(t.reshape([2, 4]))
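Because the constructor tries the extractions in order (u32, list of u32, f32, list of f32), a Python list of plain ints produces a u32 tensor, while a list containing any float fails the integer extraction and falls through to the f32 path. A minimal usage sketch, again assuming a local build of the module:

    import candle

    t = candle.Tensor([3, 1, 4, 1, 5, 9, 2, 6])  # all ints: extracted as a Vec of u32
    print(t.reshape([2, 4]).values())             # [[3, 1, 4, 1], [5, 9, 2, 6]]
    t = candle.Tensor([3.0, 1.0])                 # any float forces the f32 branch
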
From 9a9858bbe00adc0be84a63df56d5d26078bf81a1 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 2 Jul 2023 07:30:00 +0100
Subject: [PATCH 5/7] Expose a couple more ops.

---
 candle-pyo3/src/lib.rs | 87 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 82 insertions(+), 5 deletions(-)

diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index c85b41f0..74847e5e 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -155,7 +155,7 @@ impl PyTensor {
         } else if let Ok(rhs) = rhs.extract::<f64>() {
             (&self.0 + rhs).map_err(wrap_err)?
         } else {
-            Err(PyTypeError::new_err("unsupported for add"))?
+            Err(PyTypeError::new_err("unsupported rhs for add"))?
         };
         Ok(Self(tensor))
     }
@@ -170,7 +170,7 @@ impl PyTensor {
         } else if let Ok(rhs) = rhs.extract::<f64>() {
             (&self.0 * rhs).map_err(wrap_err)?
         } else {
-            Err(PyTypeError::new_err("unsupported for mul"))?
+            Err(PyTypeError::new_err("unsupported rhs for mul"))?
         };
         Ok(Self(tensor))
     }
@@ -179,21 +179,98 @@ impl PyTensor {
         self.__mul__(rhs)
     }
 
+    fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
+        let tensor = if let Ok(rhs) = rhs.extract::<PyTensor>() {
+            (&self.0 - &rhs.0).map_err(wrap_err)?
+        } else if let Ok(rhs) = rhs.extract::<f64>() {
+            (&self.0 - rhs).map_err(wrap_err)?
+        } else {
+            Err(PyTypeError::new_err("unsupported rhs for sub"))?
+        };
+        Ok(Self(tensor))
+    }
+
     // TODO: Add a PyShape type?
     fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
         Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
     }
+
+    fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
+        Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
+    }
+
+    fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
+        Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
+    }
+
+    fn squeeze(&self, dim: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
+    }
+
+    fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
+    }
+
+    fn get(&self, index: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
+    }
+
+    fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
+    }
+
+    fn sum_all(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
+    }
+
+    fn flatten_all(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
+    }
+
+    fn t(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.t().map_err(wrap_err)?))
+    }
+
+    fn contiguous(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
+    }
+
+    fn is_contiguous(&self) -> bool {
+        self.0.is_contiguous()
+    }
+
+    fn is_fortran_contiguous(&self) -> bool {
+        self.0.is_fortran_contiguous()
+    }
+
+    fn detach(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
+    }
+
+    fn copy(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
+    }
+}
+
+/// Concatenate the tensors across one axis.
+#[pyfunction]
+fn cat(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
+    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+    let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
+    Ok(PyTensor(tensor))
 }
 
 #[pyfunction]
-fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
-    let tensor = (&tensor.0 + f).map_err(wrap_err)?;
+fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
+    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
+    let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }
 
 #[pymodule]
 fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyTensor>()?;
-    m.add_function(wrap_pyfunction!(add, m)?)?;
+    m.add_function(wrap_pyfunction!(cat, m)?)?;
+    m.add_function(wrap_pyfunction!(stack, m)?)?;
     Ok(())
 }

From 5b8c6764b05cfe82340101372549aa2a97a0ffbb Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 2 Jul 2023 07:34:14 +0100
Subject: [PATCH 6/7] Add matmul/where_cond.

---
 candle-pyo3/src/lib.rs | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 74847e5e..b1504ada 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -149,6 +149,16 @@ impl PyTensor {
         self.__repr__()
     }
 
+    fn matmul(&self, rhs: &Self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
+    }
+
+    fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
+        Ok(PyTensor(
+            self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
+        ))
+    }
+
     fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
         let tensor = if let Ok(rhs) = rhs.extract::<PyTensor>() {
             (&self.0 + &rhs.0).map_err(wrap_err)?
@@ -219,6 +229,15 @@ impl PyTensor {
         Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
     }
 
+    fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> {
+        Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
+    }
+
+    fn sum(&self, dims: Vec<usize>) -> PyResult<Self> {
+        // TODO: Support a single dim as input?
+        Ok(PyTensor(self.0.sum(dims.as_slice()).map_err(wrap_err)?))
+    }
+
     fn sum_all(&self) -> PyResult<Self> {
         Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
     }

From d38897461b5e2947bf6b552a7453a078dc7a9021 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 2 Jul 2023 07:37:17 +0100
Subject: [PATCH 7/7] Add to the example.

---
 candle-pyo3/test.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index aad5e8ae..d63f752b 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -5,7 +5,8 @@ print(t)
 print("shape", t.shape, t.rank)
 print(t + t)
 
-t = candle.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
+t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
 print(t)
 print(t+t)
-print(t.reshape([2, 4]))
+t = t.reshape([2, 4])
+print(t.matmul(t.t()))
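Note that the final test.py change switches the list to start with 3.0: with the constructor from patch 4, an all-int list produces a u32 tensor, and the float literal is what routes construction through the f32 path that matmul needs. With the full series applied, the exposed surface is small but composes: elementwise add/sub/mul with tensors or scalars, shape manipulation (reshape, transpose, t, squeeze, unsqueeze, narrow), reductions (sum, sum_all), matmul and where_cond, plus the free functions cat and stack. A final usage sketch in the spirit of test.py, assuming a local build importable as `candle`:

    import candle

    a = candle.Tensor([1.0, 2.0, 3.0, 4.0]).reshape([2, 2])
    b = a.t()                              # transpose of a
    print(a.matmul(b).values())            # 2x2 product as nested lists
    print(candle.cat([a, b], 0).shape)     # concatenated along dim 0: 4x2
    print(candle.stack([a, b], 0).shape)   # new leading dim: 2x2x2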