Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 10:26:33 +00:00
Trace softmax (#568)
* Trace the softmax op.
* Inline the sum.
* Add min/max vec operations.
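The new spans are created at TRACE level, so they only produce output once a tracing subscriber is installed. A minimal sketch of that wiring, assuming the `tracing-subscriber` and `tracing-chrome` crates (the pattern candle's examples use; this setup is not part of the diff below):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The guard flushes the chrome://tracing JSON file when dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    // ... run the model; spans such as "xa-softmax" are now recorded ...
}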
@@ -1,4 +1,4 @@
-pub trait VecOps: num_traits::NumAssign + Copy {
+pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
     /// Dot-product of two vectors.
     ///
     /// # Safety
@@ -26,6 +26,40 @@ pub trait VecOps: num_traits::NumAssign + Copy {
             *res += *xs.add(i)
         }
     }
+
+    /// Maximum element in a non-empty vector.
+    ///
+    /// # Safety
+    ///
+    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
+    /// element.
+    #[inline(always)]
+    unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
+        *res = *xs;
+        for i in 1..len {
+            let x = *xs.add(i);
+            if x > *res {
+                *res = x
+            }
+        }
+    }
+
+    /// Minimum element in a non-empty vector.
+    ///
+    /// # Safety
+    ///
+    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
+    /// element.
+    #[inline(always)]
+    unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
+        *res = *xs;
+        for i in 1..len {
+            let x = *xs.add(i);
+            if x < *res {
+                *res = x
+            }
+        }
+    }
 }
 
 impl VecOps for f32 {
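A quick sketch of how the new reductions might be called, assuming `VecOps` is in scope; the sample data is illustrative, not from this commit. Per the safety contract above, the slice is non-empty and `len` matches its length:

fn reduce_demo() {
    let data = [3.0f32, -1.0, 7.0, 2.0];
    let (mut min, mut max) = (0.0f32, 0.0f32);
    // Safety: `data` is non-empty and `data.len()` is its exact length.
    unsafe {
        <f32 as VecOps>::vec_reduce_min(data.as_ptr(), &mut min, data.len());
        <f32 as VecOps>::vec_reduce_max(data.as_ptr(), &mut max, data.len());
    }
    assert_eq!((min, max), (-1.0, 7.0));
}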
@@ -291,10 +291,9 @@ struct ReduceSum<'a> {
 
 impl<'a> ReduceSum<'a> {
     #[inline(always)]
-    fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
+    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
     where
         T: WithDType,
-        F: Fn(T, T) -> T,
     {
         let mut dst = vec![start_elt; self.dst_shape.elem_count()];
         match src_l.contiguous_offsets() {
@@ -335,7 +334,7 @@ impl<'a> ReduceSum<'a> {
                         let (pre, post) = (dst_index / stride, dst_index % stride);
                         dst_index = (pre / dim) * stride + post;
                     }
-                    dst[dst_index] = f(dst[dst_index], src);
+                    dst[dst_index] += src;
                 }
             }
             None => {
@@ -347,7 +346,7 @@ impl<'a> ReduceSum<'a> {
                         let (pre, post) = (dst_index / stride, dst_index % stride);
                         dst_index = (pre / dim) * stride + post;
                     }
-                    dst[dst_index] = f(dst[dst_index], src[src_index]);
+                    dst[dst_index] += src[src_index];
                 }
             }
         }
@@ -358,7 +357,7 @@ impl<'a> ReduceSum<'a> {
 impl<'a> Map1 for ReduceSum<'a> {
     #[inline(always)]
     fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
-        self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
+        self.fold_impl(src, src_l, T::zero())
     }
 }
 
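The "Inline the sum" bullet is the change above: `fold_impl` no longer takes a generic `F: Fn(T, T) -> T` and instead hard-codes `+=`. A standalone illustration of the two shapes (names are mine, not candle's); the specialized version exposes the addition to the optimizer directly rather than through a closure parameter:

// Generic fold: the combining function arrives as an opaque type parameter.
fn fold_generic<T: Copy, F: Fn(T, T) -> T>(src: &[T], start: T, f: F) -> T {
    src.iter().fold(start, |acc, &x| f(acc, x))
}

// Specialized sum, mirroring the new `dst[dst_index] += src` in the diff.
fn sum_specialized<T: Copy + std::ops::AddAssign>(src: &[T], start: T) -> T {
    let mut acc = start;
    for &x in src {
        acc += x;
    }
    acc
}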
@@ -88,6 +88,7 @@ struct CrossAttention {
     slice_size: Option<usize>,
     span: tracing::Span,
     span_attn: tracing::Span,
+    span_softmax: tracing::Span,
     use_flash_attn: bool,
 }
 
@@ -111,6 +112,7 @@ impl CrossAttention {
         let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
         let span = tracing::span!(tracing::Level::TRACE, "xa");
         let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
         Ok(Self {
             to_q,
             to_k,
@@ -121,6 +123,7 @@ impl CrossAttention {
             slice_size,
             span,
             span_attn,
+            span_softmax,
             use_flash_attn,
         })
     }
@@ -193,9 +196,11 @@ impl CrossAttention {
             let key = key.to_dtype(DType::F32)?;
             let value = value.to_dtype(DType::F32)?;
             let xs = query.matmul(&(key.t()? * self.scale)?)?;
-            nn::ops::softmax(&xs, D::Minus1)?
-                .matmul(&value)?
-                .to_dtype(in_dtype)?
+            let xs = {
+                let _enter = self.span_softmax.enter();
+                nn::ops::softmax(&xs, D::Minus1)?
+            };
+            xs.matmul(&value)?.to_dtype(in_dtype)?
         };
         self.reshape_batch_dim_to_heads(&xs)
     }
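The softmax wrapping above relies on `Span::enter` returning a RAII guard: the span covers exactly the block in which the guard lives. A self-contained sketch of the same pattern, with a dummy computation standing in for the softmax call:

use tracing::Level;

fn traced_block() -> u64 {
    let span = tracing::span!(Level::TRACE, "xa-softmax");
    let result = {
        let _enter = span.enter(); // span opens here
        (0..1_000u64).sum()        // stand-in for nn::ops::softmax
    }; // `_enter` drops here, so only the block above is timed
    result
}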