mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add a softmax bench. (#433)
* Add a softmax bench.
* Add the vectorized sum reduce.
This commit is contained in:
@ -278,17 +278,17 @@ impl Map1Any for ReduceIndex {
|
||||
}
|
||||
}
|
||||
|
||||
struct Reduce<'a> {
|
||||
struct ReduceSum<'a> {
|
||||
dst_shape: &'a Shape,
|
||||
reduce_dims: &'a [usize],
|
||||
reduce_dims_and_stride: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
impl<'a> Reduce<'a> {
|
||||
impl<'a> ReduceSum<'a> {
|
||||
#[inline(always)]
|
||||
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
|
||||
where
|
||||
T: Clone + Copy,
|
||||
T: WithDType,
|
||||
F: Fn(T, T) -> T,
|
||||
{
|
||||
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
|
||||
@ -312,9 +312,13 @@ impl<'a> Reduce<'a> {
|
||||
.product::<usize>();
|
||||
for (dst_i, dst_v) in dst.iter_mut().enumerate() {
|
||||
let src_i = dst_i * reduce_sz;
|
||||
for &s in src[src_i..src_i + reduce_sz].iter() {
|
||||
*dst_v = f(*dst_v, s)
|
||||
}
|
||||
unsafe {
|
||||
T::vec_reduce_sum(
|
||||
src[src_i..src_i + reduce_sz].as_ptr(),
|
||||
dst_v,
|
||||
reduce_sz,
|
||||
)
|
||||
};
|
||||
}
|
||||
return Ok(dst);
|
||||
};
|
||||
@ -346,7 +350,7 @@ impl<'a> Reduce<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Reduce<'a> {
|
||||
impl<'a> Map1 for ReduceSum<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
|
||||
@ -1697,7 +1701,7 @@ impl BackendStorage for CpuStorage {
|
||||
.iter()
|
||||
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
|
||||
.collect();
|
||||
Reduce {
|
||||
ReduceSum {
|
||||
dst_shape: &dst_shape,
|
||||
reduce_dims: &reduce_dims,
|
||||
reduce_dims_and_stride,
|
||||
|
Reference in New Issue
Block a user