mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Optimize the sum for the contiguous case. (#192)
This commit is contained in:
@ -100,17 +100,46 @@ struct Sum<'a> {
|
||||
|
||||
impl<'a> Map1 for Sum<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_layout: &Layout) -> Result<Vec<T>> {
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
|
||||
for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
|
||||
let mut dst_index = unstr_index;
|
||||
// Set the sum_dims indexes to 0.
|
||||
for &(dim, stride) in self.sum_dims_and_stride.iter() {
|
||||
// The compiler is able to optimize the following in a single divmod op.
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => {
|
||||
let src = &src[o1..o2];
|
||||
// Handle the case where we sum over the last dimension separately as it is
|
||||
// fairly common and easy to optimize.
|
||||
if let [(sum_sz, 1)] = self.sum_dims_and_stride.as_slice() {
|
||||
let mut src_i = 0;
|
||||
for dst_v in dst.iter_mut() {
|
||||
for &s in src[src_i..src_i + sum_sz].iter() {
|
||||
*dst_v += s
|
||||
}
|
||||
src_i += sum_sz
|
||||
}
|
||||
return Ok(dst);
|
||||
};
|
||||
for (unstr_index, &src) in src.iter().enumerate() {
|
||||
let mut dst_index = unstr_index;
|
||||
// Set the sum_dims indexes to 0.
|
||||
for &(dim, stride) in self.sum_dims_and_stride.iter() {
|
||||
// The compiler is able to optimize the following in a single divmod op.
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] += src;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
for (unstr_index, src_index) in src_l.strided_index().enumerate() {
|
||||
let mut dst_index = unstr_index;
|
||||
// Set the sum_dims indexes to 0.
|
||||
for &(dim, stride) in self.sum_dims_and_stride.iter() {
|
||||
// The compiler is able to optimize the following in a single divmod op.
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] += src[src_index];
|
||||
}
|
||||
}
|
||||
dst[dst_index] += src[src_index];
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
|
Reference in New Issue
Block a user