mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Sum over more dims. (#197)
This commit is contained in:
@ -95,6 +95,7 @@ impl<'a> Map2 for WCond<'a> {
|
||||
|
||||
struct Sum<'a> {
|
||||
dst_shape: &'a Shape,
|
||||
sum_dims: &'a [usize],
|
||||
sum_dims_and_stride: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
@ -105,9 +106,21 @@ impl<'a> Map1 for Sum<'a> {
|
||||
match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => {
|
||||
let src = &src[o1..o2];
|
||||
// Handle the case where we sum over the last dimension separately as it is
|
||||
// fairly common and easy to optimize.
|
||||
if let [(sum_sz, 1)] = self.sum_dims_and_stride.as_slice() {
|
||||
// Handle the case where we sum over the last dimensions separately as it is
|
||||
// fairly common and easy to optimize. This relies on the layout being contiguous!
|
||||
// sum_dims is sorted, so check whether it is exactly the trailing dims a, a+1, ..., n-1.
|
||||
let sum_over_last_dims = self
|
||||
.sum_dims
|
||||
.iter()
|
||||
.rev()
|
||||
.enumerate()
|
||||
.all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
|
||||
if sum_over_last_dims {
|
||||
let sum_sz = self
|
||||
.sum_dims_and_stride
|
||||
.iter()
|
||||
.map(|(u, _)| u)
|
||||
.product::<usize>();
|
||||
let mut src_i = 0;
|
||||
for dst_v in dst.iter_mut() {
|
||||
for &s in src[src_i..src_i + sum_sz].iter() {
|
||||
@ -1014,6 +1027,7 @@ impl BackendStorage for CpuStorage {
|
||||
.collect();
|
||||
Sum {
|
||||
dst_shape: &dst_shape,
|
||||
sum_dims: &sum_dims,
|
||||
sum_dims_and_stride,
|
||||
}
|
||||
.map(self, layout)
|
||||
|
Reference in New Issue
Block a user