diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 015a162d..6458b452 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -95,6 +95,7 @@ impl<'a> Map2 for WCond<'a> { struct Sum<'a> { dst_shape: &'a Shape, + sum_dims: &'a [usize], sum_dims_and_stride: Vec<(usize, usize)>, } @@ -105,9 +106,21 @@ impl<'a> Map1 for Sum<'a> { match src_l.contiguous_offsets() { Some((o1, o2)) => { let src = &src[o1..o2]; - // Handle the case where we sum over the last dimension separately as it is - // fairly common and easy to optimize. - if let [(sum_sz, 1)] = self.sum_dims_and_stride.as_slice() { + // Handle the case where we sum over the last dimensions separately as it is + // fairly common and easy to optimize. This relies on the layout being contiguous! + // sum_dims is sorted, check if it is ranging from a to n-1. + let sum_over_last_dims = self + .sum_dims + .iter() + .rev() + .enumerate() + .all(|(i, &v)| v == src_l.shape().rank() - 1 - i); + if sum_over_last_dims { + let sum_sz = self + .sum_dims_and_stride + .iter() + .map(|(u, _)| u) + .product::<usize>(); let mut src_i = 0; for dst_v in dst.iter_mut() { for &s in src[src_i..src_i + sum_sz].iter() { @@ -1014,6 +1027,7 @@ impl BackendStorage for CpuStorage { .collect(); Sum { dst_shape: &dst_shape, + sum_dims: &sum_dims, sum_dims_and_stride, } .map(self, layout)