Add StorageRef. (#2113)

* Add the storage-ref bits.

* Add the metal implementation.
This commit is contained in:
Laurent Mazare
2024-04-23 13:23:27 +02:00
committed by GitHub
parent b2e816752b
commit 8a05743a21
10 changed files with 108 additions and 5 deletions

View File

@ -1,5 +1,5 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, DType, Layout, Result, Shape};
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
@ -334,6 +334,43 @@ impl BackendDevice for CudaDevice {
})
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {