Optimize the sum for the contiguous case. (#192)

This commit is contained in:
Laurent Mazare
2023-07-18 15:57:06 +02:00
committed by GitHub
parent 3307db204a
commit a45a3f0312

View File

@ -100,17 +100,46 @@ struct Sum<'a> {
impl<'a> Map1 for Sum<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_layout: &Layout) -> Result<Vec<T>> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
let mut dst_index = unstr_index;
// Set the sum_dims indexes to 0.
for &(dim, stride) in self.sum_dims_and_stride.iter() {
// The compiler is able to optimize the following in a single divmod op.
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
// Handle the case where we sum over the last dimension separately as it is
// fairly common and easy to optimize.
if let [(sum_sz, 1)] = self.sum_dims_and_stride.as_slice() {
let mut src_i = 0;
for dst_v in dst.iter_mut() {
for &s in src[src_i..src_i + sum_sz].iter() {
*dst_v += s
}
src_i += sum_sz
}
return Ok(dst);
};
for (unstr_index, &src) in src.iter().enumerate() {
let mut dst_index = unstr_index;
// Set the sum_dims indexes to 0.
for &(dim, stride) in self.sum_dims_and_stride.iter() {
// The compiler is able to optimize the following in a single divmod op.
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
}
dst[dst_index] += src;
}
}
None => {
for (unstr_index, src_index) in src_l.strided_index().enumerate() {
let mut dst_index = unstr_index;
// Set the sum_dims indexes to 0.
for &(dim, stride) in self.sum_dims_and_stride.iter() {
// The compiler is able to optimize the following in a single divmod op.
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
}
dst[dst_index] += src[src_index];
}
}
dst[dst_index] += src[src_index];
}
Ok(dst)
}