mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Optimize the sum for the contiguous case. (#192)
This commit is contained in:
@@ -100,17 +100,46 @@ struct Sum<'a> {
 
 impl<'a> Map1 for Sum<'a> {
     #[inline(always)]
-    fn f<T: WithDType>(&self, src: &[T], src_layout: &Layout) -> Result<Vec<T>> {
+    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
-        for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
-            let mut dst_index = unstr_index;
-            // Set the sum_dims indexes to 0.
-            for &(dim, stride) in self.sum_dims_and_stride.iter() {
-                // The compiler is able to optimize the following in a single divmod op.
-                let (pre, post) = (dst_index / stride, dst_index % stride);
-                dst_index = (pre / dim) * stride + post;
+        match src_l.contiguous_offsets() {
+            Some((o1, o2)) => {
+                let src = &src[o1..o2];
+                // Handle the case where we sum over the last dimension separately as it is
+                // fairly common and easy to optimize.
+                if let [(sum_sz, 1)] = self.sum_dims_and_stride.as_slice() {
+                    let mut src_i = 0;
+                    for dst_v in dst.iter_mut() {
+                        for &s in src[src_i..src_i + sum_sz].iter() {
+                            *dst_v += s
+                        }
+                        src_i += sum_sz
+                    }
+                    return Ok(dst);
+                };
+                for (unstr_index, &src) in src.iter().enumerate() {
+                    let mut dst_index = unstr_index;
+                    // Set the sum_dims indexes to 0.
+                    for &(dim, stride) in self.sum_dims_and_stride.iter() {
+                        // The compiler is able to optimize the following in a single divmod op.
+                        let (pre, post) = (dst_index / stride, dst_index % stride);
+                        dst_index = (pre / dim) * stride + post;
+                    }
+                    dst[dst_index] += src;
+                }
+            }
+            None => {
+                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
+                    let mut dst_index = unstr_index;
+                    // Set the sum_dims indexes to 0.
+                    for &(dim, stride) in self.sum_dims_and_stride.iter() {
+                        // The compiler is able to optimize the following in a single divmod op.
+                        let (pre, post) = (dst_index / stride, dst_index % stride);
+                        dst_index = (pre / dim) * stride + post;
+                    }
+                    dst[dst_index] += src[src_index];
+                }
             }
-            dst[dst_index] += src[src_index];
         }
         Ok(dst)
     }
Reference in New Issue
Block a user