mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Support higher order shapes for conversions.
This commit is contained in:
@ -72,7 +72,37 @@ impl PyTensor {
|
|||||||
impl<'a> MapDType for M<'a> {
|
impl<'a> MapDType for M<'a> {
|
||||||
type Output = PyObject;
|
type Output = PyObject;
|
||||||
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
|
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
|
||||||
Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
|
match t.rank() {
|
||||||
|
0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
|
||||||
|
1 => {
|
||||||
|
let v = t.to_vec1::<T>().map_err(wrap_err)?;
|
||||||
|
let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
|
||||||
|
Ok(v.to_object(self.0))
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
let v = t.to_vec2::<T>().map_err(wrap_err)?;
|
||||||
|
let v = v
|
||||||
|
.iter()
|
||||||
|
.map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
|
||||||
|
.collect::<Vec<Vec<_>>>();
|
||||||
|
Ok(v.to_object(self.0))
|
||||||
|
}
|
||||||
|
3 => {
|
||||||
|
let v = t.to_vec3::<T>().map_err(wrap_err)?;
|
||||||
|
let v = v
|
||||||
|
.iter()
|
||||||
|
.map(|v| {
|
||||||
|
v.iter()
|
||||||
|
.map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect::<Vec<Vec<Vec<_>>>>();
|
||||||
|
Ok(v.to_object(self.0))
|
||||||
|
}
|
||||||
|
n => Err(PyTypeError::new_err(format!(
|
||||||
|
"TODO: conversion to PyObject is not handled for rank {n}"
|
||||||
|
)))?,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: Handle arbitrary shapes.
|
// TODO: Handle arbitrary shapes.
|
||||||
|
Reference in New Issue
Block a user