From a45a3f031259328e13663bdb76c61829726afa11 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 18 Jul 2023 15:57:06 +0200 Subject: [PATCH] Optimize the sum for the contiguous case. (#192) --- candle-core/src/cpu_backend.rs | 47 +++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index a466f88f..0d37774c 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -100,17 +100,46 @@ struct Sum<'a> { impl<'a> Map1 for Sum<'a> { #[inline(always)] - fn f(&self, src: &[T], src_layout: &Layout) -> Result> { + fn f(&self, src: &[T], src_l: &Layout) -> Result> { let mut dst = vec![T::zero(); self.dst_shape.elem_count()]; - for (unstr_index, src_index) in src_layout.strided_index().enumerate() { - let mut dst_index = unstr_index; - // Set the sum_dims indexes to 0. - for &(dim, stride) in self.sum_dims_and_stride.iter() { - // The compiler is able to optimize the following in a single divmod op. - let (pre, post) = (dst_index / stride, dst_index % stride); - dst_index = (pre / dim) * stride + post; + match src_l.contiguous_offsets() { + Some((o1, o2)) => { + let src = &src[o1..o2]; + // Handle the case where we sum over the last dimension separately as it is + // fairly common and easy to optimize. + if let [(sum_sz, 1)] = self.sum_dims_and_stride.as_slice() { + let mut src_i = 0; + for dst_v in dst.iter_mut() { + for &s in src[src_i..src_i + sum_sz].iter() { + *dst_v += s + } + src_i += sum_sz + } + return Ok(dst); + }; + for (unstr_index, &src) in src.iter().enumerate() { + let mut dst_index = unstr_index; + // Set the sum_dims indexes to 0. + for &(dim, stride) in self.sum_dims_and_stride.iter() { + // The compiler is able to optimize the following in a single divmod op. + let (pre, post) = (dst_index / stride, dst_index % stride); + dst_index = (pre / dim) * stride + post; + } + dst[dst_index] += src; + } + } + None => { + for (unstr_index, src_index) in src_l.strided_index().enumerate() { + let mut dst_index = unstr_index; + // Set the sum_dims indexes to 0. + for &(dim, stride) in self.sum_dims_and_stride.iter() { + // The compiler is able to optimize the following in a single divmod op. + let (pre, post) = (dst_index / stride, dst_index % stride); + dst_index = (pre / dim) * stride + post; + } + dst[dst_index] += src[src_index]; + } } - dst[dst_index] += src[src_index]; } Ok(dst) }