Add StorageRef. (#2113)

* Add the storage-ref bits.

* Add the metal implementation.
This commit is contained in:
Laurent Mazare
2024-04-23 13:23:27 +02:00
committed by GitHub
parent b2e816752b
commit 8a05743a21
10 changed files with 108 additions and 5 deletions

View File

@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
@ -1787,6 +1787,19 @@ impl BackendDevice for MetalDevice {
self.storage_from_cpu_storage(&cpu_storage)
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let (count, buffer) = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
};
Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let (count, buffer) = match storage {
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),