From 8a05743a21768405217576a1b9557936be74ed90 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 23 Apr 2024 13:23:27 +0200
Subject: [PATCH] Add StorageRef. (#2113)

* Add the storage-ref bits.

* Add the metal implementation.
---
 candle-core/src/backend.rs             |  2 ++
 candle-core/src/cpu_backend/mod.rs     | 15 ++++++++++
 candle-core/src/cuda_backend/device.rs | 39 +++++++++++++++++++++++++-
 candle-core/src/device.rs              | 14 +++++++++
 candle-core/src/dtype.rs               |  8 +++++-
 candle-core/src/dummy_cuda_backend.rs  |  4 +++
 candle-core/src/dummy_metal_backend.rs |  4 +++
 candle-core/src/lib.rs                 |  2 +-
 candle-core/src/metal_backend/mod.rs   | 15 +++++++++-
 candle-core/src/tensor.rs              | 10 ++++++-
 10 files changed, 108 insertions(+), 5 deletions(-)

diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 6fe9d9df..afe3e407 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -133,6 +133,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
     /// after this call.
     unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
 
+    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
 
     fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs
index c8cf7e1e..299b1e6e 100644
--- a/candle-core/src/cpu_backend/mod.rs
+++ b/candle-core/src/cpu_backend/mod.rs
@@ -26,6 +26,17 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
+#[derive(Debug, Clone)]
+pub enum CpuStorageRef<'a> {
+    U8(&'a [u8]),
+    U32(&'a [u32]),
+    I64(&'a [i64]),
+    BF16(&'a [bf16]),
+    F16(&'a [f16]),
+    F32(&'a [f32]),
+    F64(&'a [f64]),
+}
+
 #[derive(Debug, Clone)]
 pub struct CpuDevice;
 
@@ -2445,6 +2456,10 @@ impl BackendDevice for CpuDevice {
         true
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
+        Ok(T::to_cpu_storage(s))
+    }
+
     fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
         Ok(s.clone())
     }
diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs
index 40b7293b..0aa58cac 100644
--- a/candle-core/src/cuda_backend/device.rs
+++ b/candle-core/src/cuda_backend/device.rs
@@ -1,5 +1,5 @@
 use crate::backend::BackendDevice;
-use crate::{CpuStorage, DType, Layout, Result, Shape};
+use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
@@ -334,6 +334,43 @@ impl BackendDevice for CudaDevice {
         })
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
+        let slice = match T::cpu_storage_ref(s) {
+            CpuStorageRef::U8(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            CpuStorageRef::U32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            CpuStorageRef::I64(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::I64(data)
+            }
+            CpuStorageRef::BF16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            CpuStorageRef::F16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            CpuStorageRef::F32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            CpuStorageRef::F64(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index d0bec4f7..1cd26167 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -306,6 +306,20 @@ impl Device {
         }
     }
 
+    pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
+        match self {
+            Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
+            Device::Cuda(device) => {
+                let storage = device.storage_from_slice(data)?;
+                Ok(Storage::Cuda(storage))
+            }
+            Device::Metal(device) => {
+                let storage = device.storage_from_slice(data)?;
+                Ok(Storage::Metal(storage))
+            }
+        }
+    }
+
     pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
             Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 1a698a35..de6cddc3 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -1,7 +1,7 @@
 //! Types for elements that can be stored and manipulated using tensors.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::BackendStorage;
-use crate::{CpuStorage, Error, Result};
+use crate::{CpuStorage, CpuStorageRef, Error, Result};
 
 /// The different types of elements allowed in tensors.
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@@ -100,12 +100,14 @@ pub trait WithDType:
     + 'static
     + Send
     + Sync
+    + std::any::Any
     + crate::cpu::kernels::VecOps
 {
     const DTYPE: DType;
 
     fn from_f64(v: f64) -> Self;
     fn to_f64(self) -> f64;
+    fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
 
     fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@@ -129,6 +131,10 @@ macro_rules! with_dtype {
                 $to_f64(self)
             }
 
+            fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
+                CpuStorageRef::$dtype(data)
+            }
+
             fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                 CpuStorage::$dtype(data)
             }
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 015635a5..a85ed61c 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -214,6 +214,10 @@ impl crate::backend::BackendDevice for CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs
index 47a166c1..a1c2394d 100644
--- a/candle-core/src/dummy_metal_backend.rs
+++ b/candle-core/src/dummy_metal_backend.rs
@@ -226,6 +226,10 @@ impl crate::backend::BackendDevice for MetalDevice {
         Err(Error::NotCompiledWithMetalSupport)
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 1f57ca9b..bafad1b6 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -74,7 +74,7 @@ mod variable;
 #[cfg(feature = "cudnn")]
 pub use cuda_backend::cudnn;
 
-pub use cpu_backend::CpuStorage;
+pub use cpu_backend::{CpuStorage, CpuStorageRef};
 pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
 pub use device::{Device, DeviceLocation, NdArray};
 pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 12dba381..1396899b 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1,7 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{CpuStorage, DType, Layout, Result, Shape};
+use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
 use metal::{Buffer, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
@@ -1787,6 +1787,19 @@ impl BackendDevice for MetalDevice {
         self.storage_from_cpu_storage(&cpu_storage)
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
+        let (count, buffer) = match T::cpu_storage_ref(s) {
+            CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+        };
+        Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
+    }
+
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
         let (count, buffer) = match storage {
             CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a39d6b18..dd1b44b0 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -456,7 +456,15 @@ impl Tensor {
         shape: S,
         device: &Device,
     ) -> Result<Self> {
-        Self::new_impl(array, shape.into(), device, false)
+        let shape = shape.into();
+        let n: usize = shape.elem_count();
+        let buffer_size: usize = array.len();
+        if buffer_size != n {
+            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
+        }
+        let storage = device.storage_from_slice(array)?;
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, shape, none, false))
     }
 
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
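
Usage sketch for the path rewired above: the tensor.rs hunk changes the slice constructor (Tensor::from_slice) to validate the element count and then call Device::storage_from_slice, which dispatches to the backend's storage_from_slice through WithDType::cpu_storage_ref, so CUDA and Metal can upload directly from the borrowed slice instead of first cloning it into an owned CpuStorage. The snippet below is a minimal sketch, not part of the patch; it assumes a candle-core build where the cuda feature may or may not be enabled, and cuda_if_available falls back to the CPU device when no GPU is present.

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // Use the first CUDA device when available, otherwise fall back to the CPU.
        let device = Device::cuda_if_available(0)?;
        // from_slice now checks the element count against the shape and hands the
        // borrowed data to the backend via Device::storage_from_slice, avoiding an
        // intermediate owned CpuStorage on the accelerator backends.
        let data: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let t = Tensor::from_slice(data.as_slice(), (2, 3), &device)?;
        println!("{t}");
        Ok(())
    }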