Add a softmax bench. (#433)

* Add a softmax bench.

* Add the vectorized sum reduce.
This commit is contained in:
Laurent Mazare
2023-08-13 21:09:18 +02:00
committed by GitHub
parent 9af438ac1b
commit d379a76a9e
3 changed files with 61 additions and 9 deletions

View File

@ -278,17 +278,17 @@ impl Map1Any for ReduceIndex {
}
}
struct Reduce<'a> {
struct ReduceSum<'a> {
dst_shape: &'a Shape,
reduce_dims: &'a [usize],
reduce_dims_and_stride: Vec<(usize, usize)>,
}
impl<'a> Reduce<'a> {
impl<'a> ReduceSum<'a> {
#[inline(always)]
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
where
T: Clone + Copy,
T: WithDType,
F: Fn(T, T) -> T,
{
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
@ -312,9 +312,13 @@ impl<'a> Reduce<'a> {
.product::<usize>();
for (dst_i, dst_v) in dst.iter_mut().enumerate() {
let src_i = dst_i * reduce_sz;
for &s in src[src_i..src_i + reduce_sz].iter() {
*dst_v = f(*dst_v, s)
}
unsafe {
T::vec_reduce_sum(
src[src_i..src_i + reduce_sz].as_ptr(),
dst_v,
reduce_sz,
)
};
}
return Ok(dst);
};
@ -346,7 +350,7 @@ impl<'a> Reduce<'a> {
}
}
impl<'a> Map1 for Reduce<'a> {
impl<'a> Map1 for ReduceSum<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
@ -1697,7 +1701,7 @@ impl BackendStorage for CpuStorage {
.iter()
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
.collect();
Reduce {
ReduceSum {
dst_shape: &dst_shape,
reduce_dims: &reduce_dims,
reduce_dims_and_stride,

View File

@ -12,6 +12,20 @@ pub trait VecDot: num_traits::NumAssign + Copy {
*res += *lhs.add(i) * *rhs.add(i)
}
}
/// Sum of all elements in a vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len`. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
*res = Self::zero();
for i in 0..len {
*res += *xs.add(i)
}
}
}
impl VecDot for f32 {
@ -19,6 +33,12 @@ impl VecDot for f32 {
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
ggblas::ggml::vec_dot_f32(lhs, rhs, res, len)
}
// TODO: enable the following once the updated ggblas is available.
// #[inline(always)]
// unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
// ggblas::ggml::vec_reduce_sum(xs, res, len)
// }
}
impl VecDot for f64 {}