From eb52b9b343819c547b8e4c47a8ff70cb7c632fbb Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 21 Jun 2023 10:25:56 +0100
Subject: [PATCH] Move the cpu backend specific bits apart.

---
 src/cpu_backend.rs | 99 ++++++++++++++++++++++++++++++++++++++++++++++
 src/device.rs      |  5 +--
 src/lib.rs         |  4 +-
 src/storage.rs     | 93 +++++++------------------------------------
 4 files changed, 118 insertions(+), 83 deletions(-)
 create mode 100644 src/cpu_backend.rs

diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
new file mode 100644
index 00000000..03068866
--- /dev/null
+++ b/src/cpu_backend.rs
@@ -0,0 +1,99 @@
+use crate::storage::{BinaryOp, UnaryOp};
+use crate::{DType, Error, Result, Shape, StridedIndex};
+
+// TODO: Think about whether we would be better off with a dtype and
+// a buffer as an owned slice of bytes.
+#[derive(Debug, Clone)]
+pub enum CpuStorage {
+    F32(Vec<f32>),
+    F64(Vec<f64>),
+}
+
+impl CpuStorage {
+    pub fn dtype(&self) -> DType {
+        match self {
+            Self::F32(_) => DType::F32,
+            Self::F64(_) => DType::F64,
+        }
+    }
+
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        match self {
+            Self::F32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let mul = mul as f32;
+                let add = add as f32;
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        // TODO: Different code path for the contiguous case?
+        match self {
+            Self::F32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| B::f32(storage[i])).collect();
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| B::f64(storage[i])).collect();
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn binary_impl<B: BinaryOp>(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        // The ggml implementation has different paths based on whether the rhs is contiguous
+        // or not, for now we only consider the general case but we should benchmark and do the
+        // same if it helps.
+        // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
+        match (self, rhs) {
+            (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
+                let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+                let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+                let data = lhs_index
+                    .zip(rhs_index)
+                    .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
+                    .collect();
+                Ok(Self::F32(data))
+            }
+            (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
+                let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+                let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+                let data = lhs_index
+                    .zip(rhs_index)
+                    .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
+                    .collect();
+                Ok(Self::F64(data))
+            }
+            _ => {
+                // This should be covered by the dtype check above.
+                Err(Error::DTypeMismatchBinaryOp {
+                    lhs: self.dtype(),
+                    rhs: rhs.dtype(),
+                    op: B::NAME,
+                })
+            }
+        }
+    }
+}
diff --git a/src/device.rs b/src/device.rs
index c092a347..3677cfff 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -1,7 +1,4 @@
-use crate::{
-    storage::{CpuStorage, Storage},
-    DType, Result, Shape,
-};
+use crate::{CpuStorage, DType, Result, Shape, Storage};
 
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
 pub enum Device {
diff --git a/src/lib.rs b/src/lib.rs
index 58c2ba52..175d36ad 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+mod cpu_backend;
 mod device;
 mod dtype;
 mod error;
@@ -7,10 +8,11 @@ mod storage;
 mod strided_index;
 mod tensor;
 
+pub use cpu_backend::CpuStorage;
 pub use device::Device;
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
 pub use shape::Shape;
-pub use storage::{CpuStorage, Storage};
+pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
diff --git a/src/storage.rs b/src/storage.rs
index 30161a2c..7083cc28 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1,21 +1,4 @@
-use crate::{DType, Device, Error, Result, Shape, StridedIndex};
-
-// TODO: Think about whether we would be better off with a dtype and
-// a buffer as an owned slice of bytes.
-#[derive(Debug, Clone)]
-pub enum CpuStorage {
-    F32(Vec<f32>),
-    F64(Vec<f64>),
-}
-
-impl CpuStorage {
-    pub(crate) fn dtype(&self) -> DType {
-        match self {
-            Self::F32(_) => DType::F32,
-            Self::F64(_) => DType::F64,
-        }
-    }
-}
+use crate::{CpuStorage, DType, Device, Error, Result, Shape};
 
 #[derive(Debug, Clone)]
 pub enum Storage {
@@ -23,13 +6,13 @@ pub enum Storage {
     Cuda { gpu_id: usize }, // TODO: Actually add the storage.
 }
 
-trait UnaryOp {
+pub(crate) trait UnaryOp {
     const NAME: &'static str;
     fn f32(v1: f32) -> f32;
     fn f64(v1: f64) -> f64;
 }
 
-trait BinaryOp {
+pub(crate) trait BinaryOp {
     const NAME: &'static str;
     fn f32(v1: f32, v2: f32) -> f32;
     fn f64(v1: f64, v2: f64) -> f64;
 }
@@ -157,20 +140,10 @@ impl Storage {
     ) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
-            Storage::Cpu(storage) => match storage {
-                CpuStorage::F32(storage) => {
-                    let index = StridedIndex::new(shape.dims(), stride);
-                    let mul = mul as f32;
-                    let add = add as f32;
-                    let data = index.map(|i| storage[i] * mul + add).collect();
-                    Ok(Storage::Cpu(CpuStorage::F32(data)))
-                }
-                CpuStorage::F64(storage) => {
-                    let index = StridedIndex::new(shape.dims(), stride);
-                    let data = index.map(|i| storage[i] * mul + add).collect();
-                    Ok(Storage::Cpu(CpuStorage::F64(data)))
-                }
-            },
+            Storage::Cpu(storage) => {
+                let storage = storage.affine_impl(shape, stride, mul, add)?;
+                Ok(Self::Cpu(storage))
+            }
             Self::Cuda { .. } => todo!(),
         }
     }
@@ -178,18 +151,10 @@ impl Storage {
     fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
-            Storage::Cpu(storage) => match storage {
-                CpuStorage::F32(storage) => {
-                    let index = StridedIndex::new(shape.dims(), stride);
-                    let data = index.map(|i| B::f32(storage[i])).collect();
-                    Ok(Storage::Cpu(CpuStorage::F32(data)))
-                }
-                CpuStorage::F64(storage) => {
-                    let index = StridedIndex::new(shape.dims(), stride);
-                    let data = index.map(|i| B::f64(storage[i])).collect();
-                    Ok(Storage::Cpu(CpuStorage::F64(data)))
-                }
-            },
+            Storage::Cpu(storage) => {
+                let storage = storage.unary_impl::<B>(shape, stride)?;
+                Ok(Self::Cpu(storage))
+            }
             Self::Cuda { .. } => todo!(),
         }
     }
@@ -204,39 +169,11 @@ impl Storage {
     ) -> Result<Self> {
         self.same_device(rhs, B::NAME)?;
         self.same_dtype(rhs, B::NAME)?;
-        // The ggml implementation has different paths based on whether the rhs is contiguous
-        // or not, for now we only consider the general case but we should benchmark and do the
-        // same if it helps.
-        // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
         match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => match (lhs, rhs) {
-                (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
-                    let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
-                    let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
-                    let data = lhs_index
-                        .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
-                        .collect();
-                    Ok(Storage::Cpu(CpuStorage::F32(data)))
-                }
-                (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
-                    let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
-                    let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
-                    let data = lhs_index
-                        .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
-                        .collect();
-                    Ok(Storage::Cpu(CpuStorage::F64(data)))
-                }
-                _ => {
-                    // This should be covered by the dtype check above.
-                    Err(Error::DTypeMismatchBinaryOp {
-                        lhs: lhs.dtype(),
-                        rhs: rhs.dtype(),
-                        op: B::NAME,
-                    })
-                }
-            },
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
+                Ok(Self::Cpu(storage))
+            }
             (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
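
Note on the dispatch pattern introduced by this patch: each operation is meant
to be a zero-sized type implementing UnaryOp (or BinaryOp), so the generic
unary_impl::<B> monomorphizes into one tight loop per (op, dtype) pair with no
per-element virtual dispatch. The following is a minimal standalone sketch of
that pattern, not part of the patch: the Sqr op is a hypothetical example, and
strides are omitted for brevity (the real code walks a StridedIndex).

    trait UnaryOp {
        const NAME: &'static str;
        fn f32(v1: f32) -> f32;
        fn f64(v1: f64) -> f64;
    }

    // Hypothetical example op, not present in the patch.
    struct Sqr;

    impl UnaryOp for Sqr {
        const NAME: &'static str = "sqr";
        fn f32(v1: f32) -> f32 {
            v1 * v1
        }
        fn f64(v1: f64) -> f64 {
            v1 * v1
        }
    }

    #[derive(Debug)]
    enum CpuStorage {
        F32(Vec<f32>),
        F64(Vec<f64>),
    }

    impl CpuStorage {
        // Mirrors the dtype-preserving dispatch of CpuStorage::unary_impl,
        // restricted to contiguous data for the sake of the example.
        fn unary_impl<B: UnaryOp>(&self) -> Self {
            match self {
                Self::F32(s) => Self::F32(s.iter().map(|&v| B::f32(v)).collect()),
                Self::F64(s) => Self::F64(s.iter().map(|&v| B::f64(v)).collect()),
            }
        }
    }

    fn main() {
        let s = CpuStorage::F32(vec![1.0, 2.0, 3.0]);
        // Prints: sqr -> F32([1.0, 4.0, 9.0])
        println!("{} -> {:?}", Sqr::NAME, s.unary_impl::<Sqr>());
    }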