From bcfbb1dca1f3c75a18e820c9c817124d5767d1a2 Mon Sep 17 00:00:00 2001 From: laurent Date: Fri, 23 Jun 2023 16:23:12 +0100 Subject: [PATCH] More efficient CPU broadcasting implementation. --- src/cpu_backend.rs | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index ca0352d8..292cd66a 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -23,10 +23,12 @@ fn unary_map T>(shape: &Shape, stride: &[usize], f: F) -> // This function maps over two strided index sequences. It supports broadcasting in case // `lhs_stride` or `rhs_stride` has a length shorter than `shape`. -fn binary_map T>( +fn binary_map T>( shape: &Shape, lhs_stride: &[usize], rhs_stride: &[usize], + lhs: &[T], + rhs: &[T], mut f: F, ) -> Vec { let dims = shape.dims(); @@ -35,21 +37,26 @@ fn binary_map T>( let elem_count = shape.elem_count(); if broadcast_ldims == 0 && broadcast_rdims == 0 { if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) { - (0..shape.elem_count()).map(|i| f((i, i))).collect() + (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect() } else { let lhs_index = StridedIndex::new(dims, lhs_stride); let rhs_index = StridedIndex::new(dims, rhs_stride); - lhs_index.zip(rhs_index).map(f).collect() + lhs_index + .zip(rhs_index) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect() } } else if broadcast_rdims == 0 { let mut res = Vec::new(); res.reserve(elem_count); - let lhs_index: Vec<_> = StridedIndex::new(dims, lhs_stride).collect(); + let lhs_v: Vec = StridedIndex::new(dims, lhs_stride) + .map(|i| lhs[i]) + .collect(); let mut i = 0; for rhs_i in StridedIndex::new(dims, rhs_stride) { - res.push(f((lhs_index[i], rhs_i))); + res.push(f(lhs_v[i], rhs[rhs_i])); i += 1; - if i >= lhs_index.len() { + if i >= lhs_v.len() { i = 0 } } @@ -57,12 +64,14 @@ fn binary_map T>( } else if broadcast_ldims == 0 { let mut res = Vec::new(); res.reserve(elem_count); - let rhs_index: Vec<_> = StridedIndex::new(dims, rhs_stride).collect(); + let rhs_v: Vec = StridedIndex::new(dims, rhs_stride) + .map(|i| rhs[i]) + .collect(); let mut i = 0; for lhs_i in StridedIndex::new(dims, lhs_stride) { - res.push(f((lhs_i, rhs_index[i]))); + res.push(f(lhs[lhs_i], rhs_v[i])); i += 1; - if i >= rhs_index.len() { + if i >= rhs_v.len() { i = 0 } } @@ -144,21 +153,15 @@ impl CpuStorage { ) -> Result { match (self, rhs) { (Self::F32(lhs), Self::F32(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| { - B::f32(lhs[l], rhs[r]) - }); + let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32); Ok(Self::F32(data)) } (Self::F64(lhs), Self::F64(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| { - B::f64(lhs[l], rhs[r]) - }); + let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64); Ok(Self::F64(data)) } (Self::U32(lhs), Self::U32(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| { - B::u32(lhs[l], rhs[r]) - }); + let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32); Ok(Self::U32(data)) } _ => {