mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add StorageRef. (#2113)
* Add the storage-ref bits. * Add the metal implementation.
This commit is contained in:
@ -133,6 +133,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
|||||||
/// after this call.
|
/// after this call.
|
||||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||||
|
|
||||||
|
/// Creates backend storage from a borrowed slice of host elements of type `T`.
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
|
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
|
||||||
|
@ -26,6 +26,17 @@ pub enum CpuStorage {
|
|||||||
F64(Vec<f64>),
|
F64(Vec<f64>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum CpuStorageRef<'a> {
|
||||||
|
U8(&'a [u8]),
|
||||||
|
U32(&'a [u32]),
|
||||||
|
I64(&'a [i64]),
|
||||||
|
BF16(&'a [bf16]),
|
||||||
|
F16(&'a [f16]),
|
||||||
|
F32(&'a [f32]),
|
||||||
|
F64(&'a [f64]),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CpuDevice;
|
pub struct CpuDevice;
|
||||||
|
|
||||||
@ -2445,6 +2456,10 @@ impl BackendDevice for CpuDevice {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds CPU storage from a host slice by delegating to `T::to_cpu_storage`.
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
    let storage = T::to_cpu_storage(s);
    Ok(storage)
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Ok(s.clone())
|
Ok(s.clone())
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::backend::BackendDevice;
|
use crate::backend::BackendDevice;
|
||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||||
pub use candle_kernels as kernels;
|
pub use candle_kernels as kernels;
|
||||||
pub use cudarc;
|
pub use cudarc;
|
||||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
||||||
@ -334,6 +334,43 @@ impl BackendDevice for CudaDevice {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||||
|
let slice = match T::cpu_storage_ref(s) {
|
||||||
|
CpuStorageRef::U8(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::U32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::I64(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::BF16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::F16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::F32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
CpuStorageRef::F64(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
let slice = match storage {
|
let slice = match storage {
|
||||||
CpuStorage::U8(storage) => {
|
CpuStorage::U8(storage) => {
|
||||||
|
@ -306,6 +306,20 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
|
||||||
|
match self {
|
||||||
|
Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
|
||||||
|
Device::Cuda(device) => {
|
||||||
|
let storage = device.storage_from_slice(data)?;
|
||||||
|
Ok(Storage::Cuda(storage))
|
||||||
|
}
|
||||||
|
Device::Metal(device) => {
|
||||||
|
let storage = device.storage_from_slice(data)?;
|
||||||
|
Ok(Storage::Metal(storage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
//! Types for elements that can be stored and manipulated using tensors.
|
//! Types for elements that can be stored and manipulated using tensors.
|
||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
use crate::{CpuStorage, Error, Result};
|
use crate::{CpuStorage, CpuStorageRef, Error, Result};
|
||||||
|
|
||||||
/// The different types of elements allowed in tensors.
|
/// The different types of elements allowed in tensors.
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
@ -100,12 +100,14 @@ pub trait WithDType:
|
|||||||
+ 'static
|
+ 'static
|
||||||
+ Send
|
+ Send
|
||||||
+ Sync
|
+ Sync
|
||||||
|
+ std::any::Any
|
||||||
+ crate::cpu::kernels::VecOps
|
+ crate::cpu::kernels::VecOps
|
||||||
{
|
{
|
||||||
const DTYPE: DType;
|
const DTYPE: DType;
|
||||||
|
|
||||||
fn from_f64(v: f64) -> Self;
|
fn from_f64(v: f64) -> Self;
|
||||||
fn to_f64(self) -> f64;
|
fn to_f64(self) -> f64;
|
||||||
|
/// Wraps a borrowed slice of this element type in a `CpuStorageRef` tied to the slice's lifetime.
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||||
|
|
||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
@ -129,6 +131,10 @@ macro_rules! with_dtype {
|
|||||||
$to_f64(self)
|
$to_f64(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Borrows `data` as the `CpuStorageRef` variant named by `$dtype`; the slice itself is not copied.
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
||||||
|
CpuStorageRef::$dtype(data)
|
||||||
|
}
|
||||||
|
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||||
CpuStorage::$dtype(data)
|
CpuStorage::$dtype(data)
|
||||||
}
|
}
|
||||||
|
@ -214,6 +214,10 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always fails with `Error::NotCompiledWithCudaSupport`; the input slice is ignored.
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
@ -226,6 +226,10 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always fails with `Error::NotCompiledWithMetalSupport`; the input slice is ignored.
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
@ -74,7 +74,7 @@ mod variable;
|
|||||||
#[cfg(feature = "cudnn")]
|
#[cfg(feature = "cudnn")]
|
||||||
pub use cuda_backend::cudnn;
|
pub use cuda_backend::cudnn;
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
pub use cpu_backend::{CpuStorage, CpuStorageRef};
|
||||||
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||||
pub use device::{Device, DeviceLocation, NdArray};
|
pub use device::{Device, DeviceLocation, NdArray};
|
||||||
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::backend::{BackendDevice, BackendStorage};
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||||
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -1787,6 +1787,19 @@ impl BackendDevice for MetalDevice {
|
|||||||
self.storage_from_cpu_storage(&cpu_storage)
|
self.storage_from_cpu_storage(&cpu_storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||||
|
let (count, buffer) = match T::cpu_storage_ref(s) {
|
||||||
|
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
};
|
||||||
|
Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
|
||||||
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||||
let (count, buffer) = match storage {
|
let (count, buffer) = match storage {
|
||||||
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
@ -456,7 +456,15 @@ impl Tensor {
|
|||||||
shape: S,
|
shape: S,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::new_impl(array, shape.into(), device, false)
|
let shape = shape.into();
|
||||||
|
let n: usize = shape.elem_count();
|
||||||
|
let buffer_size: usize = array.len();
|
||||||
|
if buffer_size != n {
|
||||||
|
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||||
|
}
|
||||||
|
let storage = device.storage_from_slice(array)?;
|
||||||
|
let none = BackpropOp::none();
|
||||||
|
Ok(from_storage(storage, shape, none, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||||
|
Reference in New Issue
Block a user