Trace softmax (#568)

* Trace the softmax op.

* Inline the sum.

* Add min/max vec operations.
Author:    Laurent Mazare
Date:      2023-08-23 15:25:50 +01:00
Committed: GitHub
Parent:    075b505480
Commit:    329f661d9b

3 changed files with 47 additions and 9 deletions

Changed file 1 of 3:

@@ -1,4 +1,4 @@
-pub trait VecOps: num_traits::NumAssign + Copy {
+pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
     /// Dot-product of two vectors.
     ///
     /// # Safety
@@ -26,6 +26,40 @@ pub trait VecOps: num_traits::NumAssign + Copy {
             *res += *xs.add(i)
         }
     }
+
+    /// Maximum element in a non-empty vector.
+    ///
+    /// # Safety
+    ///
+    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
+    /// element.
+    #[inline(always)]
+    unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
+        *res = *xs;
+        for i in 1..len {
+            let x = *xs.add(i);
+            if x > *res {
+                *res = x
+            }
+        }
+    }
+
+    /// Minimum element in a non-empty vector.
+    ///
+    /// # Safety
+    ///
+    /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
+    /// element.
+    #[inline(always)]
+    unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
+        *res = *xs;
+        for i in 1..len {
+            let x = *xs.add(i);
+            if x < *res {
+                *res = x
+            }
+        }
+    }
 }

 impl VecOps for f32 {
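
A quick sketch of how the new reductions are meant to be called. The `MinMax` trait below is a hypothetical stand-in for `VecOps` with the `num_traits::NumAssign` bound dropped so the snippet compiles without dependencies; in candle you would call `f32::vec_reduce_max` through `VecOps` itself.

    // Hypothetical trimmed-down stand-in for VecOps, only to demonstrate
    // the calling convention of the new vec_reduce_max-style reductions.
    trait MinMax: PartialOrd + Copy {
        /// # Safety
        /// `xs` must point to at least `len` (> 0) valid elements and
        /// `res` must point to a valid element.
        unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
            // Seed the result with the first element, then scan the rest.
            *res = *xs;
            for i in 1..len {
                let x = *xs.add(i);
                if x > *res {
                    *res = x
                }
            }
        }
    }

    impl MinMax for f32 {}

    fn main() {
        let xs = [1.0f32, -3.5, 7.25, 0.0];
        let mut mx = 0f32;
        // Safety: `xs` is non-empty and we pass its exact length.
        unsafe { f32::vec_reduce_max(xs.as_ptr(), &mut mx, xs.len()) };
        assert_eq!(mx, 7.25);
    }

Seeding `*res` with the first element, rather than with a neutral element such as negative infinity, is what forces the non-empty precondition, but it also keeps the reduction correct for integer types that have no infinity value.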

Changed file 2 of 3:

@@ -291,10 +291,9 @@ struct ReduceSum<'a> {
 impl<'a> ReduceSum<'a> {
     #[inline(always)]
-    fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
+    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
     where
         T: WithDType,
-        F: Fn(T, T) -> T,
     {
         let mut dst = vec![start_elt; self.dst_shape.elem_count()];
         match src_l.contiguous_offsets() {
@@ -335,7 +334,7 @@ impl<'a> ReduceSum<'a> {
                         let (pre, post) = (dst_index / stride, dst_index % stride);
                         dst_index = (pre / dim) * stride + post;
                     }
-                    dst[dst_index] = f(dst[dst_index], src);
+                    dst[dst_index] += src;
                 }
             }
             None => {
@@ -347,7 +346,7 @@ impl<'a> ReduceSum<'a> {
                         let (pre, post) = (dst_index / stride, dst_index % stride);
                         dst_index = (pre / dim) * stride + post;
                     }
-                    dst[dst_index] = f(dst[dst_index], src[src_index]);
+                    dst[dst_index] += src[src_index];
                 }
             }
         }
@@ -358,7 +357,7 @@ impl<'a> ReduceSum<'a> {
 impl<'a> Map1 for ReduceSum<'a> {
     #[inline(always)]
     fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
-        self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
+        self.fold_impl(src, src_l, T::zero())
     }
 }
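
The `f` parameter was removed because `fold_impl` has a single call site, the `Map1` impl just above, and it always passed `|x, y| x + y`; hard-coding `+=` drops a type parameter and a layer of indirection. A minimal before/after sketch of the inner-loop shape, with hypothetical names rather than the candle code:

    // Before: generic over the combining closure.
    fn fold_with<T: Copy, F: Fn(T, T) -> T>(dst: &mut [T], src: &[T], f: F) {
        for (d, &s) in dst.iter_mut().zip(src) {
            *d = f(*d, s);
        }
    }

    // After: the sum is inlined, so the extra type parameter (and a
    // separate monomorphization per closure) disappears.
    fn fold_sum(dst: &mut [f32], src: &[f32]) {
        for (d, &s) in dst.iter_mut().zip(src) {
            *d += s;
        }
    }

    fn main() {
        let src = [1.0f32, 2.0, 3.0];
        let (mut a, mut b) = ([0.0f32; 3], [0.0f32; 3]);
        fold_with(&mut a, &src, |x, y| x + y);
        fold_sum(&mut b, &src);
        assert_eq!(a, b);
    }

In practice the optimizer would usually inline such a closure anyway, so this is mostly a simplification, with the bonus that the addition can no longer hide behind a generic call.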

Changed file 3 of 3:

@@ -88,6 +88,7 @@ struct CrossAttention {
     slice_size: Option<usize>,
     span: tracing::Span,
     span_attn: tracing::Span,
+    span_softmax: tracing::Span,
     use_flash_attn: bool,
 }
@@ -111,6 +112,7 @@ impl CrossAttention {
         let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
         let span = tracing::span!(tracing::Level::TRACE, "xa");
         let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
         Ok(Self {
             to_q,
             to_k,
@@ -121,6 +123,7 @@
             slice_size,
             span,
             span_attn,
+            span_softmax,
             use_flash_attn,
         })
     }
@@ -193,9 +196,11 @@ impl CrossAttention {
             let key = key.to_dtype(DType::F32)?;
             let value = value.to_dtype(DType::F32)?;
             let xs = query.matmul(&(key.t()? * self.scale)?)?;
-            nn::ops::softmax(&xs, D::Minus1)?
-                .matmul(&value)?
-                .to_dtype(in_dtype)?
+            let xs = {
+                let _enter = self.span_softmax.enter();
+                nn::ops::softmax(&xs, D::Minus1)?
+            };
+            xs.matmul(&value)?.to_dtype(in_dtype)?
         };
         self.reshape_batch_dim_to_heads(&xs)
     }
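
The new `xa-softmax` span only shows up if a tracing subscriber is installed. A sketch of a setup that records spans to a Chrome trace file, assuming the `tracing-chrome` and `tracing-subscriber` crates (candle's examples enable a setup along these lines behind a `--tracing` flag):

    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // Writes trace-<timestamp>.json in the current directory; the guard
        // flushes the file on drop, so keep it alive for the whole run.
        let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        // ... run the model here; the "xa", "xa-attn", and new "xa-softmax"
        // spans will appear in the trace with their timings, viewable in
        // chrome://tracing or Perfetto.
    }

Because `enter` returns a guard that exits the span on drop, the `let _enter = ...` inside the block above scopes the `xa-softmax` span to exactly the softmax call.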