mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Trace softmax (#568)
* Trace the softmax op. * Inline the sum. * Add min/max vec operations.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
pub trait VecOps: num_traits::NumAssign + Copy {
|
||||
pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
|
||||
/// Dot-product of two vectors.
|
||||
///
|
||||
/// # Safety
|
||||
@ -26,6 +26,40 @@ pub trait VecOps: num_traits::NumAssign + Copy {
|
||||
*res += *xs.add(i)
|
||||
}
|
||||
}
|
||||
|
||||
/// Maximum element in a non-empty vector.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
|
||||
/// element.
|
||||
#[inline(always)]
|
||||
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
|
||||
*res = *xs;
|
||||
for i in 1..len {
|
||||
let x = *xs.add(i);
|
||||
if x > *res {
|
||||
*res = x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimum element in a non-empty vector.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
|
||||
/// element.
|
||||
#[inline(always)]
|
||||
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
|
||||
*res = *xs;
|
||||
for i in 1..len {
|
||||
let x = *xs.add(i);
|
||||
if x < *res {
|
||||
*res = x
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecOps for f32 {
|
||||
|
Reference in New Issue
Block a user