Trace softmax (#568)

* Trace the softmax op.

* Inline the sum.

* Add min/max vec operations.
This commit is contained in:
Laurent Mazare
2023-08-23 15:25:50 +01:00
committed by GitHub
parent 075b505480
commit 329f661d9b
3 changed files with 47 additions and 9 deletions

View File

@ -1,4 +1,4 @@
pub trait VecOps: num_traits::NumAssign + Copy {
pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
/// Dot-product of two vectors.
///
/// # Safety
@ -26,6 +26,40 @@ pub trait VecOps: num_traits::NumAssign + Copy {
*res += *xs.add(i)
}
}
/// Maximum element in a non-empty vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
let x = *xs.add(i);
if x > *res {
*res = x
}
}
}
/// Minimum element in a non-empty vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
let x = *xs.add(i);
if x < *res {
*res = x
}
}
}
}
impl VecOps for f32 {