Tweaks to softmax. (#745)

This commit is contained in:
Laurent Mazare
2023-09-05 16:22:27 +02:00
committed by GitHub
parent 1c9e5394a5
commit 6615daf242
2 changed files with 84 additions and 19 deletions

View File

@ -1,4 +1,7 @@
pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy { pub trait VecOps: num_traits::NumAssign + Copy {
fn min(self, rhs: Self) -> Self;
fn max(self, rhs: Self) -> Self;
/// Dot-product of two vectors. /// Dot-product of two vectors.
/// ///
/// # Safety /// # Safety
@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) { unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs; *res = *xs;
for i in 1..len { for i in 1..len {
let x = *xs.add(i); *res = (*res).max(*xs.add(i))
if x > *res {
*res = x
}
} }
} }
@ -54,15 +54,22 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) { unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs; *res = *xs;
for i in 1..len { for i in 1..len {
let x = *xs.add(i); *res = (*res).min(*xs.add(i))
if x < *res {
*res = x
}
} }
} }
} }
impl VecOps for f32 { impl VecOps for f32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)] #[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len) super::vec_dot_f32(lhs, rhs, res, len)
@ -75,6 +82,16 @@ impl VecOps for f32 {
} }
impl VecOps for half::f16 { impl VecOps for half::f16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)] #[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32; let mut res_f32 = 0f32;
@ -83,11 +100,61 @@ impl VecOps for half::f16 {
} }
} }
impl VecOps for f64 {} impl VecOps for f64 {
impl VecOps for half::bf16 {} #[inline(always)]
impl VecOps for u8 {} fn min(self, other: Self) -> Self {
impl VecOps for u32 {} Self::min(self, other)
impl VecOps for i64 {} }
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for half::bf16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for u8 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for u32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
#[inline(always)] #[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) { pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {

View File

@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim {
.zip(dst.par_chunks_mut(dim_m1)) .zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| { .for_each(|(src, dst)| {
let mut max = T::neg_infinity(); let mut max = T::neg_infinity();
for &s in src.iter() { unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
max = T::max(s, max)
}
let mut sum_exp = T::zero();
for (s, d) in src.iter().zip(dst.iter_mut()) { for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = (*s - max).exp(); *d = (*s - max).exp();
sum_exp += *d
} }
let mut sum_exp = T::zero();
unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
for d in dst.iter_mut() { for d in dst.iter_mut() {
*d /= sum_exp *d /= sum_exp
} }