mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Sum over more dims. (#197)
This commit is contained in:
@ -95,6 +95,7 @@ impl<'a> Map2 for WCond<'a> {
|
|||||||
|
|
||||||
struct Sum<'a> {
|
struct Sum<'a> {
|
||||||
dst_shape: &'a Shape,
|
dst_shape: &'a Shape,
|
||||||
|
sum_dims: &'a [usize],
|
||||||
sum_dims_and_stride: Vec<(usize, usize)>,
|
sum_dims_and_stride: Vec<(usize, usize)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,9 +106,21 @@ impl<'a> Map1 for Sum<'a> {
|
|||||||
match src_l.contiguous_offsets() {
|
match src_l.contiguous_offsets() {
|
||||||
Some((o1, o2)) => {
|
Some((o1, o2)) => {
|
||||||
let src = &src[o1..o2];
|
let src = &src[o1..o2];
|
||||||
// Handle the case where we sum over the last dimension separately as it is
|
// Handle the case where we sum over the last dimensions separately as it is
|
||||||
// fairly common and easy to optimize.
|
// fairly common and easy to optimize. This relies on the layout being contiguous!
|
||||||
if let [(sum_sz, 1)] = self.sum_dims_and_stride.as_slice() {
|
// sum_dims is sorted; check whether it covers exactly the trailing dimensions a..=n-1.
|
||||||
|
let sum_over_last_dims = self
|
||||||
|
.sum_dims
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.enumerate()
|
||||||
|
.all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
|
||||||
|
if sum_over_last_dims {
|
||||||
|
let sum_sz = self
|
||||||
|
.sum_dims_and_stride
|
||||||
|
.iter()
|
||||||
|
.map(|(u, _)| u)
|
||||||
|
.product::<usize>();
|
||||||
let mut src_i = 0;
|
let mut src_i = 0;
|
||||||
for dst_v in dst.iter_mut() {
|
for dst_v in dst.iter_mut() {
|
||||||
for &s in src[src_i..src_i + sum_sz].iter() {
|
for &s in src[src_i..src_i + sum_sz].iter() {
|
||||||
@ -1014,6 +1027,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
.collect();
|
.collect();
|
||||||
Sum {
|
Sum {
|
||||||
dst_shape: &dst_shape,
|
dst_shape: &dst_shape,
|
||||||
|
sum_dims: &sum_dims,
|
||||||
sum_dims_and_stride,
|
sum_dims_and_stride,
|
||||||
}
|
}
|
||||||
.map(self, layout)
|
.map(self, layout)
|
||||||
|
Reference in New Issue
Block a user