diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index aa35703d..c897510e 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -9,6 +9,7 @@ pub(crate) trait BackendStorage: Sized {
 
     fn device(&self) -> &Self::Device;
 
+    // Maybe this should return a Cow instead so that no copy is done on the cpu case.
     fn to_cpu_storage(&self) -> Result<CpuStorage>;
 
     fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index a20d032d..e6e2e7a2 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,3 +1,4 @@
+use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOp, UnaryOp};
 use crate::{DType, Error, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
@@ -14,6 +15,9 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
+#[derive(Debug, Clone)]
+pub struct CpuDevice;
+
 trait Map1 {
     fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
 
@@ -519,7 +523,15 @@ fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
 }
 
 impl CpuStorage {
-    pub fn dtype(&self) -> DType {
+    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
+        D::cpu_storage_as_slice(self)
+    }
+}
+
+impl BackendStorage for CpuStorage {
+    type Device = CpuDevice;
+
+    fn dtype(&self) -> DType {
         match self {
             Self::U8(_) => DType::U8,
             Self::U32(_) => DType::U32,
@@ -530,11 +542,7 @@
         }
     }
 
-    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
-        D::cpu_storage_as_slice(self)
-    }
-
-    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
+    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         // TODO: find a way around the quadratic number of cases below.
         match (self, dtype) {
             (Self::U8(storage), DType::BF16) => {
@@ -684,7 +692,7 @@
         }
     }
 
-    pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+    fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         let src_dims = layout.dims();
         let mut dst_dims = src_dims.to_vec();
         for &sum_dim in sum_dims.iter() {
@@ -706,7 +714,7 @@ impl CpuStorage {
         .map(self, layout)
     }
 
-    pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
+    fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         // [self] stores data in a contiguous way starting at offset 0.
         match self {
             Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
@@ -717,11 +725,11 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
+    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         Affine(mul, add).map(self, layout)
     }
 
-    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
+    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
         // TODO: Have some generic map for functions that apply on num_traits::Float elements.
         match self {
             Self::BF16(storage) => {
@@ -745,7 +753,7 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+    fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
         match self {
             Self::BF16(storage) => {
                 let data = unary_map(storage, layout, B::bf16);
@@ -774,12 +782,7 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn binary_impl<B: BinaryOp>(
-        &self,
-        rhs: &Self,
-        lhs_l: &Layout,
-        rhs_l: &Layout,
-    ) -> Result<Self> {
+    fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
         match (self, rhs) {
             (Self::BF16(lhs), Self::BF16(rhs)) => {
                 let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16);
@@ -816,12 +819,7 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn copy_strided_src(
-        &self,
-        dst: &mut Self,
-        dst_offset: usize,
-        src_l: &Layout,
-    ) -> Result<()> {
+    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
         match (self, dst) {
             (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@@ -841,7 +839,7 @@ impl CpuStorage {
         Ok(())
     }
 
-    pub(crate) fn where_cond(
+    fn where_cond(
         &self,
         layout: &Layout,
         t: &Self,
@@ -854,7 +852,7 @@ impl CpuStorage {
         WCond(pred, layout).map(t, t_l, f, f_l)
     }
 
-    pub(crate) fn conv1d(
+    fn conv1d(
         &self,
         l: &Layout,
         kernel: &Self,
@@ -864,7 +862,7 @@ impl CpuStorage {
         Conv1D(params).map(self, l, kernel, kernel_l)
     }
 
-    pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
+    fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
        let ids = self.as_slice::<u32>()?;
        let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
        Embedding {
@@ -876,7 +874,7 @@ impl CpuStorage {
         .map(rhs, rhs_l)
     }
 
-    pub(crate) fn matmul(
+    fn matmul(
         &self,
         rhs: &Self,
         bmnk: (usize, usize, usize, usize),
@@ -886,7 +884,39 @@ impl CpuStorage {
         MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
     }
 
-    pub(crate) fn rand_uniform(shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<Self> {
+    fn device(&self) -> &Self::Device {
+        &CpuDevice
+    }
+
+    fn try_clone(&self, _: &Layout) -> Result<Self> {
+        Ok(self.clone())
+    }
+
+    fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        Ok(self.clone())
+    }
+}
+
+impl BackendDevice for CpuDevice {
+    type Storage = CpuStorage;
+
+    fn location(&self) -> crate::DeviceLocation {
+        crate::DeviceLocation::Cpu
+    }
+
+    fn same_device(&self, _: &Self) -> bool {
+        true
+    }
+
+    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
+        Ok(s.clone())
+    }
+
+    fn new(_: usize) -> Result<Self> {
+        Ok(Self)
+    }
+
+    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
         use rand::prelude::*;
 
         let elem_count = shape.elem_count();
@@ -902,7 +932,7 @@ impl CpuStorage {
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f32, _>(uniform))
                 }
-                Ok(Self::F32(data))
+                Ok(CpuStorage::F32(data))
             }
             DType::F64 => {
                 let mut data = Vec::new();
@@ -911,12 +941,12 @@ impl CpuStorage {
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(uniform))
                 }
-                Ok(Self::F64(data))
+                Ok(CpuStorage::F64(data))
             }
         }
     }
 
-    pub(crate) fn rand_normal(shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<Self> {
+    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
         use rand::prelude::*;
 
         let elem_count = shape.elem_count();
@@ -933,7 +963,7 @@ impl CpuStorage {
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
                 }
-                Ok(Self::F32(data))
+                Ok(CpuStorage::F32(data))
             }
             DType::F64 => {
                 let mut data = Vec::new();
@@ -941,32 +971,34 @@
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
                 }
-                Ok(Self::F64(data))
+                Ok(CpuStorage::F64(data))
             }
         }
     }
 
-    pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
+    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
         let elem_count = shape.elem_count();
-        match dtype {
-            DType::U8 => Self::U8(vec![1u8; elem_count]),
-            DType::U32 => Self::U32(vec![1u32; elem_count]),
-            DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]),
-            DType::F16 => Self::F16(vec![f16::ONE; elem_count]),
-            DType::F32 => Self::F32(vec![1f32; elem_count]),
-            DType::F64 => Self::F64(vec![1f64; elem_count]),
-        }
+        let storage = match dtype {
+            DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
+            DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
+            DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
+            DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
+            DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
+            DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
+        };
+        Ok(storage)
     }
 
-    pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
+    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
         let elem_count = shape.elem_count();
-        match dtype {
-            DType::U8 => Self::U8(vec![0u8; elem_count]),
-            DType::U32 => Self::U32(vec![0u32; elem_count]),
-            DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]),
-            DType::F16 => Self::F16(vec![f16::ZERO; elem_count]),
-            DType::F32 => Self::F32(vec![0f32; elem_count]),
-            DType::F64 => Self::F64(vec![0f64; elem_count]),
-        }
+        let storage = match dtype {
+            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
+            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
+            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
+            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
+            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
+            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
+        };
+        Ok(storage)
     }
 }
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index b428922b..ca408529 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -1,4 +1,5 @@
 use crate::backend::BackendDevice;
+use crate::cpu_backend::CpuDevice;
 use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 
 /// A `DeviceLocation` represents a physical device whereas multiple `Device`
@@ -117,7 +118,7 @@
     ) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = CpuStorage::rand_uniform(shape, dtype, lo, up)?;
+                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
@@ -136,7 +137,7 @@
     ) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = CpuStorage::rand_normal(shape, dtype, mean, std)?;
+                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
@@ -149,7 +150,7 @@
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = CpuStorage::ones_impl(shape, dtype);
+                let storage = CpuDevice.ones_impl(shape, dtype)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
@@ -162,7 +163,7 @@
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = CpuStorage::zeros_impl(shape, dtype);
+                let storage = CpuDevice.zeros_impl(shape, dtype)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 8ce70f64..632abcc4 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -1,3 +1,4 @@
+use crate::backend::BackendStorage;
 use crate::{CpuStorage, Error, Result};
 
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
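
Illustration (not part of the diff): with `CpuStorage` and `CpuDevice` now implementing the `BackendStorage`/`BackendDevice` traits, helper code can be written once against the traits instead of per backend. Below is a minimal sketch under the assumption that it lives inside candle-core; the helper name and the explicit `where` bound are hypothetical, while `zeros_impl`, `to_cpu_storage`, and `dtype` are the trait methods that appear in the diff above.

use crate::backend::{BackendDevice, BackendStorage};
use crate::{DType, Result, Shape};

// Hypothetical device-generic helper: allocate zeros on any backend,
// copy the result back to the host, and report its dtype.
fn zeros_dtype<D: BackendDevice>(dev: &D, shape: &Shape, dtype: DType) -> Result<DType>
where
    D::Storage: BackendStorage,
{
    let storage = dev.zeros_impl(shape, dtype)?;
    // For CpuDevice this is a clone; the Cow suggestion in backend.rs would avoid that copy.
    let cpu = storage.to_cpu_storage()?;
    Ok(cpu.dtype())
}

The same helper would work unchanged for a CUDA backend once its device and storage types implement these traits, which is the point of moving the formerly inherent `CpuStorage` methods behind `BackendStorage`/`BackendDevice`.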