Add StorageRef. (#2113)

* Add the storage-ref bits.

* Add the metal implementation.
This commit is contained in:
Laurent Mazare
2024-04-23 13:23:27 +02:00
committed by GitHub
parent b2e816752b
commit 8a05743a21
10 changed files with 108 additions and 5 deletions

View File

@ -133,6 +133,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
/// after this call. /// after this call.
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>; unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>; fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>; fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;

View File

@ -26,6 +26,17 @@ pub enum CpuStorage {
F64(Vec<f64>), F64(Vec<f64>),
} }
/// Borrowed view over host-side tensor data.
///
/// Each variant wraps a shared slice (`&'a [T]`) of one element type, so the
/// enum carries both the data and a tag identifying its element type without
/// owning the data.
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
U8(&'a [u8]),
U32(&'a [u32]),
I64(&'a [i64]),
BF16(&'a [bf16]),
F16(&'a [f16]),
F32(&'a [f32]),
F64(&'a [f64]),
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CpuDevice; pub struct CpuDevice;
@ -2445,6 +2456,10 @@ impl BackendDevice for CpuDevice {
true true
} }
// Build this backend's storage from a host slice by delegating to the
// element type's `to_cpu_storage` conversion.
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
    let storage = T::to_cpu_storage(s);
    Ok(storage)
}
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
Ok(s.clone()) Ok(s.clone())
} }

View File

@ -1,5 +1,5 @@
use crate::backend::BackendDevice; use crate::backend::BackendDevice;
use crate::{CpuStorage, DType, Layout, Result, Shape}; use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels; pub use candle_kernels as kernels;
pub use cudarc; pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
@ -334,6 +334,43 @@ impl BackendDevice for CudaDevice {
}) })
} }
// Build a `CudaStorage` from a host slice.
//
// The slice is first tagged with its element type via `T::cpu_storage_ref`,
// then passed to `htod_sync_copy`, and the resulting device data is wrapped
// in the `CudaStorageSlice` variant of the same element type. Any error from
// the copy is converted with `.w()` and returned early.
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
    let slice = match T::cpu_storage_ref(s) {
        CpuStorageRef::U8(host) => CudaStorageSlice::U8(self.htod_sync_copy(host).w()?),
        CpuStorageRef::U32(host) => CudaStorageSlice::U32(self.htod_sync_copy(host).w()?),
        CpuStorageRef::I64(host) => CudaStorageSlice::I64(self.htod_sync_copy(host).w()?),
        CpuStorageRef::BF16(host) => CudaStorageSlice::BF16(self.htod_sync_copy(host).w()?),
        CpuStorageRef::F16(host) => CudaStorageSlice::F16(self.htod_sync_copy(host).w()?),
        CpuStorageRef::F32(host) => CudaStorageSlice::F32(self.htod_sync_copy(host).w()?),
        CpuStorageRef::F64(host) => CudaStorageSlice::F64(self.htod_sync_copy(host).w()?),
    };
    let device = self.clone();
    Ok(CudaStorage { slice, device })
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> { fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage { let slice = match storage {
CpuStorage::U8(storage) => { CpuStorage::U8(storage) => {

View File

@ -306,6 +306,20 @@ impl Device {
} }
} }
/// Creates a `Storage` on this device from a slice of elements.
///
/// The CPU variant converts the slice with `to_cpu_storage`; the CUDA and
/// Metal variants forward to the backend device's `storage_from_slice` and
/// propagate its error, if any.
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
    let storage = match self {
        Device::Cpu => Storage::Cpu(data.to_cpu_storage()),
        Device::Cuda(cuda) => Storage::Cuda(cuda.storage_from_slice(data)?),
        Device::Metal(metal) => Storage::Metal(metal.storage_from_slice(data)?),
    };
    Ok(storage)
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> { pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self { match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),

View File

@ -1,7 +1,7 @@
//! Types for elements that can be stored and manipulated using tensors. //! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)] #![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage; use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result}; use crate::{CpuStorage, CpuStorageRef, Error, Result};
/// The different types of elements allowed in tensors. /// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@ -100,12 +100,14 @@ pub trait WithDType:
+ 'static + 'static
+ Send + Send
+ Sync + Sync
+ std::any::Any
+ crate::cpu::kernels::VecOps + crate::cpu::kernels::VecOps
{ {
const DTYPE: DType; const DTYPE: DType;
fn from_f64(v: f64) -> Self; fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64; fn to_f64(self) -> f64;
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage; fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
fn to_cpu_storage(data: &[Self]) -> CpuStorage { fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@ -129,6 +131,10 @@ macro_rules! with_dtype {
$to_f64(self) $to_f64(self)
} }
// Wrap the borrowed slice in the `CpuStorageRef` variant named by the
// macro's `$dtype` argument; only the reference is stored, no data is copied.
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
CpuStorageRef::$dtype(data)
}
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage { fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
CpuStorage::$dtype(data) CpuStorage::$dtype(data)
} }

View File

@ -214,6 +214,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
// Stub: ignores the slice and always returns
// `Error::NotCompiledWithCudaSupport`.
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }

View File

@ -226,6 +226,10 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
// Stub: ignores the slice and always returns
// `Error::NotCompiledWithMetalSupport`.
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }

View File

@ -74,7 +74,7 @@ mod variable;
#[cfg(feature = "cudnn")] #[cfg(feature = "cudnn")]
pub use cuda_backend::cudnn; pub use cuda_backend::cudnn;
pub use cpu_backend::CpuStorage; pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray}; pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};

View File

@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage}; use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape}; use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels}; use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger}; use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap; use std::collections::HashMap;
@ -1787,6 +1787,19 @@ impl BackendDevice for MetalDevice {
self.storage_from_cpu_storage(&cpu_storage) self.storage_from_cpu_storage(&cpu_storage)
} }
// Build Metal storage from a host slice.
//
// The slice is tagged with its element type via `T::cpu_storage_ref`, passed
// to `new_buffer_with_data`, and the resulting buffer is wrapped together
// with the element count and `T::DTYPE`. A buffer-creation error is
// propagated to the caller.
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
    let (elem_count, created) = match T::cpu_storage_ref(s) {
        CpuStorageRef::U8(host) => (host.len(), self.new_buffer_with_data(host)),
        CpuStorageRef::U32(host) => (host.len(), self.new_buffer_with_data(host)),
        CpuStorageRef::I64(host) => (host.len(), self.new_buffer_with_data(host)),
        CpuStorageRef::BF16(host) => (host.len(), self.new_buffer_with_data(host)),
        CpuStorageRef::F16(host) => (host.len(), self.new_buffer_with_data(host)),
        CpuStorageRef::F32(host) => (host.len(), self.new_buffer_with_data(host)),
        CpuStorageRef::F64(host) => (host.len(), self.new_buffer_with_data(host)),
    };
    // Surface the buffer error before anything else is constructed, as the
    // original argument order did.
    let buffer = created?;
    let storage = Self::Storage::new(buffer, self.clone(), elem_count, T::DTYPE);
    Ok(storage)
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let (count, buffer) = match storage { let (count, buffer) = match storage {
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),

View File

@ -456,7 +456,15 @@ impl Tensor {
shape: S, shape: S,
device: &Device, device: &Device,
) -> Result<Self> { ) -> Result<Self> {
Self::new_impl(array, shape.into(), device, false) let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
} }
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {