Sum over more dims. (#197)

This commit is contained in:
Laurent Mazare
2023-07-19 07:46:32 +02:00
committed by GitHub
parent 76dcc7a381
commit 67e20c3792

View File

@ -95,6 +95,7 @@ impl<'a> Map2 for WCond<'a> {
struct Sum<'a> {
dst_shape: &'a Shape,
sum_dims: &'a [usize],
sum_dims_and_stride: Vec<(usize, usize)>,
}
@ -105,9 +106,21 @@ impl<'a> Map1 for Sum<'a> {
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
// Handle the case where we sum over the last dimension separately as it is
// fairly common and easy to optimize.
if let [(sum_sz, 1)] = self.sum_dims_and_stride.as_slice() {
// Handle the case where we sum over the last dimensions separately as it is
// fairly common and easy to optimize. This rely on the layout being contiguous!
// sum_dims is sorted, check if it is ranging from a to n-1.
let sum_over_last_dims = self
.sum_dims
.iter()
.rev()
.enumerate()
.all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
if sum_over_last_dims {
let sum_sz = self
.sum_dims_and_stride
.iter()
.map(|(u, _)| u)
.product::<usize>();
let mut src_i = 0;
for dst_v in dst.iter_mut() {
for &s in src[src_i..src_i + sum_sz].iter() {
@ -1014,6 +1027,7 @@ impl BackendStorage for CpuStorage {
.collect();
Sum {
dst_shape: &dst_shape,
sum_dims: &sum_dims,
sum_dims_and_stride,
}
.map(self, layout)