From c62cb73a7f7487ee76ba0d93c1cde7706b1025b3 Mon Sep 17 00:00:00 2001 From: laurent Date: Sun, 2 Jul 2023 07:07:22 +0100 Subject: [PATCH] Support higher order shapes for conversions. --- candle-pyo3/src/lib.rs | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 1d3e4efd..4328ac01 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -72,7 +72,37 @@ impl PyTensor { impl<'a> MapDType for M<'a> { type Output = PyObject; fn f(&self, t: &Tensor) -> PyResult { - Ok(t.to_scalar::().map_err(wrap_err)?.to_py(self.0)) + match t.rank() { + 0 => Ok(t.to_scalar::().map_err(wrap_err)?.to_py(self.0)), + 1 => { + let v = t.to_vec1::().map_err(wrap_err)?; + let v = v.iter().map(|v| v.to_py(self.0)).collect::>(); + Ok(v.to_object(self.0)) + } + 2 => { + let v = t.to_vec2::().map_err(wrap_err)?; + let v = v + .iter() + .map(|v| v.iter().map(|v| v.to_py(self.0)).collect()) + .collect::>>(); + Ok(v.to_object(self.0)) + } + 3 => { + let v = t.to_vec3::().map_err(wrap_err)?; + let v = v + .iter() + .map(|v| { + v.iter() + .map(|v| v.iter().map(|v| v.to_py(self.0)).collect()) + .collect() + }) + .collect::>>>(); + Ok(v.to_object(self.0)) + } + n => Err(PyTypeError::new_err(format!( + "TODO: conversion to PyObject is not handled for rank {n}" + )))?, + } } } // TODO: Handle arbitrary shapes.