mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
add as_cuda_slice_mut to CudaStorage and CudaDType (#2859)
This commit is contained in:
@ -21,7 +21,9 @@ impl BenchDevice for Device {
|
||||
Device::Cpu => Ok(()),
|
||||
Device::Cuda(device) => {
|
||||
#[cfg(feature = "cuda")]
|
||||
return Ok(device.synchronize()?);
|
||||
return Ok(device
|
||||
.synchronize()
|
||||
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||
}
|
||||
|
@ -1001,6 +1001,7 @@ pub struct CudaStorage {
|
||||
|
||||
pub trait CudaDType: Sized {
|
||||
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
|
||||
fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice<Self>>;
|
||||
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
|
||||
}
|
||||
|
||||
@ -1019,6 +1020,18 @@ macro_rules! cuda_dtype {
|
||||
}
|
||||
}
|
||||
|
||||
fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice<Self>> {
|
||||
match s.slice {
|
||||
CudaStorageSlice::$dtype(ref mut data) => Ok(data),
|
||||
_ => Err(crate::Error::UnexpectedDType {
|
||||
expected: DType::$dtype,
|
||||
got: s.dtype(),
|
||||
msg: "unexpected dtype",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
|
||||
let slice = CudaStorageSlice::$dtype(slice);
|
||||
CudaStorage { slice, device }
|
||||
@ -1042,6 +1055,10 @@ impl CudaStorage {
|
||||
pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
|
||||
T::as_cuda_slice(self)
|
||||
}
|
||||
|
||||
pub fn as_cuda_slice_mut<T: CudaDType>(&mut self) -> Result<&mut CudaSlice<T>> {
|
||||
T::as_cuda_slice_mut(self)
|
||||
}
|
||||
}
|
||||
|
||||
fn gemm_config<T>(
|
||||
|
Reference in New Issue
Block a user