add as_cuda_slice_mut to CudaStorage and CudaDType (#2859)

This commit is contained in:
Zack Angelo
2025-04-01 12:34:52 -05:00
committed by GitHub
parent 9541467d6b
commit b4daa03e59
2 changed files with 20 additions and 1 deletions

View File

@ -21,7 +21,9 @@ impl BenchDevice for Device {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
return Ok(device
.synchronize()
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}

View File

@ -1001,6 +1001,7 @@ pub struct CudaStorage {
pub trait CudaDType: Sized {
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice<Self>>;
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
}
@ -1019,6 +1020,18 @@ macro_rules! cuda_dtype {
}
}
fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice<Self>> {
match s.slice {
CudaStorageSlice::$dtype(ref mut data) => Ok(data),
_ => Err(crate::Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}
.bt()),
}
}
fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
let slice = CudaStorageSlice::$dtype(slice);
CudaStorage { slice, device }
@ -1042,6 +1055,10 @@ impl CudaStorage {
pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
T::as_cuda_slice(self)
}
pub fn as_cuda_slice_mut<T: CudaDType>(&mut self) -> Result<&mut CudaSlice<T>> {
T::as_cuda_slice_mut(self)
}
}
fn gemm_config<T>(