From dfe197f791b9462c83d7ec9cc141886c868628a7 Mon Sep 17 00:00:00 2001 From: laurent Date: Sun, 2 Jul 2023 07:19:46 +0100 Subject: [PATCH] Handle more input types to create tensors. --- candle-pyo3/src/lib.rs | 20 ++++++++++++++++++-- candle-pyo3/test.py | 5 +++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 29503c8f..c85b41f0 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -63,8 +63,19 @@ trait MapDType { impl PyTensor { #[new] // TODO: Handle arbitrary input dtype and shape. - fn new(f: f32) -> PyResult { - Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?)) + fn new(py: Python<'_>, vs: PyObject) -> PyResult { + let tensor = if let Ok(vs) = vs.extract::(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>(py) { + Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>(py) { + Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? + } else { + Err(PyTypeError::new_err("incorrect type for tensor"))? + }; + Ok(Self(tensor)) } /// Gets the tensor data as a Python value/array/array of array/... @@ -167,6 +178,11 @@ impl PyTensor { fn __rmul__(&self, rhs: &PyAny) -> PyResult { self.__mul__(rhs) } + + // TODO: Add a PyShape type? + fn reshape(&self, shape: Vec) -> PyResult { + Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) + } } #[pyfunction] diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 0db3fab9..aad5e8ae 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -4,3 +4,8 @@ t = candle.Tensor(42.0) print(t) print("shape", t.shape, t.rank) print(t + t) + +t = candle.Tensor([3, 1, 4, 1, 5, 9, 2, 6]) +print(t) +print(t+t) +print(t.reshape([2, 4]))