mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
add as_cuda_slice_mut to CudaStorage and CudaDType (#2859)
This commit is contained in:
@ -21,7 +21,9 @@ impl BenchDevice for Device {
|
|||||||
Device::Cpu => Ok(()),
|
Device::Cpu => Ok(()),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
return Ok(device.synchronize()?);
|
return Ok(device
|
||||||
|
.synchronize()
|
||||||
|
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||||
}
|
}
|
||||||
|
@ -1001,6 +1001,7 @@ pub struct CudaStorage {
|
|||||||
|
|
||||||
pub trait CudaDType: Sized {
|
pub trait CudaDType: Sized {
|
||||||
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
|
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
|
||||||
|
fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice<Self>>;
|
||||||
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
|
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1019,6 +1020,18 @@ macro_rules! cuda_dtype {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn as_cuda_slice_mut(s: &mut CudaStorage) -> Result<&mut CudaSlice<Self>> {
|
||||||
|
match s.slice {
|
||||||
|
CudaStorageSlice::$dtype(ref mut data) => Ok(data),
|
||||||
|
_ => Err(crate::Error::UnexpectedDType {
|
||||||
|
expected: DType::$dtype,
|
||||||
|
got: s.dtype(),
|
||||||
|
msg: "unexpected dtype",
|
||||||
|
}
|
||||||
|
.bt()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
|
fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
|
||||||
let slice = CudaStorageSlice::$dtype(slice);
|
let slice = CudaStorageSlice::$dtype(slice);
|
||||||
CudaStorage { slice, device }
|
CudaStorage { slice, device }
|
||||||
@ -1042,6 +1055,10 @@ impl CudaStorage {
|
|||||||
pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
|
pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
|
||||||
T::as_cuda_slice(self)
|
T::as_cuda_slice(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn as_cuda_slice_mut<T: CudaDType>(&mut self) -> Result<&mut CudaSlice<T>> {
|
||||||
|
T::as_cuda_slice_mut(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gemm_config<T>(
|
fn gemm_config<T>(
|
||||||
|
Reference in New Issue
Block a user