From 87a37b3bf3b6fd5034269c10c21c8f91e0223eb0 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 22 Jun 2023 11:01:49 +0100 Subject: [PATCH] Retrieve data from the gpu. --- examples/cuda_basics.rs | 10 ++++++---- src/tensor.rs | 36 +++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs index 0a4825fa..046091a3 100644 --- a/examples/cuda_basics.rs +++ b/examples/cuda_basics.rs @@ -1,12 +1,14 @@ use anyhow::Result; -use candle::{Device, Tensor}; +use candle::{DType, Device, Tensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?; println!("{:?}", x.to_vec1::()?); - let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?; - let z = (y * 3.)?; - println!("{:?}", z.to_vec1::()?); + let x = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?; + let y = (x * 3.)?; + println!("{:?}", y.to_vec1::()?); + let x = Tensor::ones((3, 2), DType::F32, &device)?; + println!("{:?}", x.to_vec2::()?); Ok(()) } diff --git a/src/tensor.rs b/src/tensor.rs index 02105573..e8e01d5c 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -204,12 +204,13 @@ impl Tensor { shape: self.shape().clone(), }); } + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + Ok::<_, Error>(data[0]) + }; match &self.storage { - Storage::Cpu(cpu_storage) => { - let data = S::cpu_storage_as_slice(cpu_storage)?; - Ok(data[0]) - } - Storage::Cuda { .. } => todo!(), + Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -261,19 +262,20 @@ impl Tensor { pub fn to_vec2(&self) -> Result>> { let (dim1, dim2) = self.shape().r2()?; - match &self.storage { - Storage::Cpu(cpu_storage) => { - let data = S::cpu_storage_as_slice(cpu_storage)?; - let mut rows = vec![]; - let mut src_index = self.strided_index(); - for _idx_row in 0..dim1 { - let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect(); - rows.push(row) - } - assert!(src_index.next().is_none()); - Ok(rows) + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + let mut rows = vec![]; + let mut src_index = self.strided_index(); + for _idx_row in 0..dim1 { + let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect(); + rows.push(row) } - Storage::Cuda { .. } => todo!(), + assert!(src_index.next().is_none()); + Ok(rows) + }; + match &self.storage { + Storage::Cpu(storage) => from_cpu_storage(storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } }