mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add StorageRef. (#2113)
* Add the storage-ref bits. * Add the metal implementation.
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
@ -1787,6 +1787,19 @@ impl BackendDevice for MetalDevice {
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let (count, buffer) = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
};
|
||||
Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
let (count, buffer) = match storage {
|
||||
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
|
Reference in New Issue
Block a user