mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Support higher order shapes for conversions.
This commit is contained in:
@ -72,7 +72,37 @@ impl PyTensor {
|
||||
impl<'a> MapDType for M<'a> {
|
||||
type Output = PyObject;
|
||||
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
|
||||
Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
|
||||
match t.rank() {
|
||||
0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
|
||||
1 => {
|
||||
let v = t.to_vec1::<T>().map_err(wrap_err)?;
|
||||
let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
|
||||
Ok(v.to_object(self.0))
|
||||
}
|
||||
2 => {
|
||||
let v = t.to_vec2::<T>().map_err(wrap_err)?;
|
||||
let v = v
|
||||
.iter()
|
||||
.map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
|
||||
.collect::<Vec<Vec<_>>>();
|
||||
Ok(v.to_object(self.0))
|
||||
}
|
||||
3 => {
|
||||
let v = t.to_vec3::<T>().map_err(wrap_err)?;
|
||||
let v = v
|
||||
.iter()
|
||||
.map(|v| {
|
||||
v.iter()
|
||||
.map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
|
||||
.collect()
|
||||
})
|
||||
.collect::<Vec<Vec<Vec<_>>>>();
|
||||
Ok(v.to_object(self.0))
|
||||
}
|
||||
n => Err(PyTypeError::new_err(format!(
|
||||
"TODO: conversion to PyObject is not handled for rank {n}"
|
||||
)))?,
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: Handle arbitrary shapes.
|
||||
|
Reference in New Issue
Block a user