mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Trace softmax (#568)
* Trace the softmax op.
* Inline the sum.
* Add min/max vec operations.
This commit is contained in:
@@ -291,10 +291,9 @@ struct ReduceSum<'a> {
|
||||
|
||||
impl<'a> ReduceSum<'a> {
|
||||
#[inline(always)]
|
||||
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
|
||||
fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
|
||||
where
|
||||
T: WithDType,
|
||||
F: Fn(T, T) -> T,
|
||||
{
|
||||
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
|
||||
match src_l.contiguous_offsets() {
|
||||
@@ -335,7 +334,7 @@ impl<'a> ReduceSum<'a> {
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] = f(dst[dst_index], src);
|
||||
dst[dst_index] += src;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
@@ -347,7 +346,7 @@ impl<'a> ReduceSum<'a> {
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] = f(dst[dst_index], src[src_index]);
|
||||
dst[dst_index] += src[src_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -358,7 +357,7 @@ impl<'a> ReduceSum<'a> {
|
||||
impl<'a> Map1 for ReduceSum<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
|
||||
self.fold_impl(src, src_l, T::zero())
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user