From b4daa03e598b516ee4dca5864b70f7254642b7bd Mon Sep 17 00:00:00 2001 From: Zack Angelo Date: Tue, 1 Apr 2025 12:34:52 -0500 Subject: [PATCH] add as_cuda_slice_mut to CudaStorage and CudaDType (#2859) --- candle-core/benches/benchmarks/mod.rs | 4 +++- candle-core/src/cuda_backend/mod.rs | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 721b292d..b0d2244f 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -21,7 +21,9 @@ impl BenchDevice for Device { Device::Cpu => Ok(()), Device::Cuda(device) => { #[cfg(feature = "cuda")] - return Ok(device.synchronize()?); + return Ok(device + .synchronize() + .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?); #[cfg(not(feature = "cuda"))] panic!("Cuda device without cuda feature enabled: {:?}", device) } diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 2cd97c18..c71b9694 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1001,6 +1001,7 @@ pub struct CudaStorage { pub trait CudaDType: Sized { fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice>; + fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice>; fn wrap_cuda_slice(s: CudaSlice, dev: CudaDevice) -> CudaStorage; } @@ -1019,6 +1020,18 @@ macro_rules! cuda_dtype { } } + fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice> { + match s.slice { + CudaStorageSlice::$dtype(ref mut data) => Ok(data), + _ => Err(crate::Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + msg: "unexpected dtype", + } + .bt()), + } + } + fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { let slice = CudaStorageSlice::$dtype(slice); CudaStorage { slice, device } @@ -1042,6 +1055,10 @@ impl CudaStorage { pub fn as_cuda_slice(&self) -> Result<&CudaSlice> { T::as_cuda_slice(self) } + + pub fn as_cuda_slice_mut(&mut self) -> Result<&mut CudaSlice> { + T::as_cuda_slice_mut(self) + } } fn gemm_config(