From 83e75b3af8ddb058b6515c4a186af79208c8e592 Mon Sep 17 00:00:00 2001 From: laurent Date: Fri, 23 Jun 2023 15:49:11 +0100 Subject: [PATCH] Optimize for the unstrided case. --- src/cpu_backend.rs | 57 +++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index b9b665d7..d85407f6 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -13,6 +13,28 @@ pub enum CpuStorage { F64(Vec), } +// This function maps over two strided index sequences. It supports broadcasting in case +// `lhs_stride` or `rhs_stride` has a length shorter than `shape`. +fn binary_map T>( + shape: &Shape, + lhs_stride: &[usize], + rhs_stride: &[usize], + mut f: F, +) -> Vec { + let dims = shape.dims(); + if dims.len() == lhs_stride.len() && dims.len() == rhs_stride.len() { + if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) { + (0..shape.elem_count()).map(|i| f((i, i))).collect() + } else { + let lhs_index = StridedIndex::new(dims, lhs_stride); + let rhs_index = StridedIndex::new(dims, rhs_stride); + lhs_index.zip(rhs_index).map(f).collect() + } + } else { + todo!("implement broadcast") + } +} + impl CpuStorage { pub fn dtype(&self) -> DType { match self { @@ -86,40 +108,23 @@ impl CpuStorage { lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { - let dims = shape.dims(); - if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() { - todo!("implement broadcast"); - } - // The ggml implementation has different paths based on whether the rhs is contiguous - // or not, for now we only consider the general case but we should benchmark and do the - // same if it helps. - // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895 match (self, rhs) { (Self::F32(lhs), Self::F32(rhs)) => { - let lhs_index = StridedIndex::new(shape.dims(), lhs_stride); - let rhs_index = StridedIndex::new(shape.dims(), rhs_stride); - let data = lhs_index - .zip(rhs_index) - .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i])) - .collect(); + let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| { + B::f32(lhs[l], rhs[r]) + }); Ok(Self::F32(data)) } (Self::F64(lhs), Self::F64(rhs)) => { - let lhs_index = StridedIndex::new(shape.dims(), lhs_stride); - let rhs_index = StridedIndex::new(shape.dims(), rhs_stride); - let data = lhs_index - .zip(rhs_index) - .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i])) - .collect(); + let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| { + B::f64(lhs[l], rhs[r]) + }); Ok(Self::F64(data)) } (Self::U32(lhs), Self::U32(rhs)) => { - let lhs_index = StridedIndex::new(shape.dims(), lhs_stride); - let rhs_index = StridedIndex::new(shape.dims(), rhs_stride); - let data = lhs_index - .zip(rhs_index) - .map(|(lhs_i, rhs_i)| B::u32(lhs[lhs_i], rhs[rhs_i])) - .collect(); + let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| { + B::u32(lhs[l], rhs[r]) + }); Ok(Self::U32(data)) } _ => {