diff --git a/Makefile b/Makefile index cb472d80..cc967702 100644 --- a/Makefile +++ b/Makefile @@ -9,4 +9,9 @@ clean: test: cargo test +pyo3-test: + cargo build --profile=release-with-debug --package candle-pyo3 + ln -f -s ./target/release-with-debug/libcandle.so candle.so + PYTHONPATH=. python3 candle-pyo3/test.py + all: test diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 136f8a4f..fd013b9b 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,3 +1,4 @@ +// TODO: Handle negative dimension indexes. use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyTuple; @@ -10,7 +11,23 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::(format!("{err:?}")) } -#[derive(Clone)] +#[derive(Clone, Debug)] +struct PyShape(Vec); + +impl<'source> pyo3::FromPyObject<'source> for PyShape { + fn extract(ob: &'source PyAny) -> PyResult { + let dims: Vec = pyo3::FromPyObject::extract(ob)?; + Ok(PyShape(dims)) + } +} + +impl From for ::candle::Shape { + fn from(val: PyShape) -> Self { + val.0.into() + } +} + +#[derive(Clone, Debug)] #[pyclass(name = "Tensor")] struct PyTensor(Tensor); @@ -279,16 +296,15 @@ impl PyTensor { Ok(Self(tensor)) } - // TODO: Add a PyShape type? - fn reshape(&self, shape: Vec) -> PyResult { + fn reshape(&self, shape: PyShape) -> PyResult { Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) } - fn broadcast_as(&self, shape: Vec) -> PyResult { + fn broadcast_as(&self, shape: PyShape) -> PyResult { Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) } - fn broadcast_left(&self, shape: Vec) -> PyResult { + fn broadcast_left(&self, shape: PyShape) -> PyResult { Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) } @@ -381,11 +397,59 @@ fn tensor(py: Python<'_>, vs: PyObject) -> PyResult { PyTensor::new(py, vs) } +#[pyfunction] +#[pyo3(signature = (shape, *, device=None))] +fn rand(_py: Python<'_>, shape: PyShape, device: Option) -> PyResult { + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, device=None))] +fn randn(_py: Python<'_>, shape: PyShape, device: Option) -> PyResult { + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, dtype=None, device=None))] +fn ones( + _py: Python<'_>, + shape: PyShape, + dtype: Option, + device: Option, +) -> PyResult { + let dtype = dtype.map_or(DType::F32, |dt| dt.0); + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, dtype=None, device=None))] +fn zeros( + _py: Python<'_>, + shape: PyShape, + dtype: Option, + device: Option, +) -> PyResult { + let dtype = dtype.map_or(DType::F32, |dt| dt.0); + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + #[pymodule] fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(cat, m)?)?; + m.add_function(wrap_pyfunction!(ones, m)?)?; + m.add_function(wrap_pyfunction!(rand, m)?)?; + m.add_function(wrap_pyfunction!(randn, m)?)?; m.add_function(wrap_pyfunction!(tensor, m)?)?; m.add_function(wrap_pyfunction!(stack, m)?)?; + m.add_function(wrap_pyfunction!(zeros, m)?)?; Ok(()) } diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 8f906060..160a099d 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -13,3 +13,6 @@ t = t.reshape([2, 4]) print(t.matmul(t.t())) print(t.to_dtype("u8")) + +t = candle.randn((5, 3)) +print(t)