diff --git a/src/tensor.rs b/src/tensor.rs index 09e5d66c..161a4787 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -359,23 +359,24 @@ impl Tensor { pub fn to_vec3(&self) -> Result>>> { let (dim1, dim2, dim3) = self.shape().r3()?; - match &self.storage { - Storage::Cpu(cpu_storage) => { - let data = S::cpu_storage_as_slice(cpu_storage)?; - let mut top_rows = vec![]; - let mut src_index = self.strided_index(); - for _idx in 0..dim1 { - let mut rows = vec![]; - for _jdx in 0..dim2 { - let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect(); - rows.push(row) - } - top_rows.push(rows); + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + let mut top_rows = vec![]; + let mut src_index = self.strided_index(); + for _idx in 0..dim1 { + let mut rows = vec![]; + for _jdx in 0..dim2 { + let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect(); + rows.push(row) } - assert!(src_index.next().is_none()); - Ok(top_rows) + top_rows.push(rows); } - Storage::Cuda { .. } => todo!(), + assert!(src_index.next().is_none()); + Ok(top_rows) + }; + match &self.storage { + Storage::Cpu(storage) => from_cpu_storage(storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } }