Support higher order shapes for conversions.

This commit is contained in:
laurent
2023-07-02 07:07:22 +01:00
parent fa58c7643d
commit c62cb73a7f

View File

@ -72,7 +72,37 @@ impl PyTensor {
impl<'a> MapDType for M<'a> {
type Output = PyObject;
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
match t.rank() {
0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
1 => {
let v = t.to_vec1::<T>().map_err(wrap_err)?;
let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
Ok(v.to_object(self.0))
}
2 => {
let v = t.to_vec2::<T>().map_err(wrap_err)?;
let v = v
.iter()
.map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
.collect::<Vec<Vec<_>>>();
Ok(v.to_object(self.0))
}
3 => {
let v = t.to_vec3::<T>().map_err(wrap_err)?;
let v = v
.iter()
.map(|v| {
v.iter()
.map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
.collect()
})
.collect::<Vec<Vec<Vec<_>>>>();
Ok(v.to_object(self.0))
}
n => Err(PyTypeError::new_err(format!(
"TODO: conversion to PyObject is not handled for rank {n}"
)))?,
}
}
}
// TODO: Handle arbitrary shapes.