mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Tweaks to softmax. (#745)
This commit is contained in:
@ -1,4 +1,7 @@
|
|||||||
pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
|
pub trait VecOps: num_traits::NumAssign + Copy {
|
||||||
|
fn min(self, rhs: Self) -> Self;
|
||||||
|
fn max(self, rhs: Self) -> Self;
|
||||||
|
|
||||||
/// Dot-product of two vectors.
|
/// Dot-product of two vectors.
|
||||||
///
|
///
|
||||||
/// # Safety
|
/// # Safety
|
||||||
@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
|
|||||||
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
|
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
|
||||||
*res = *xs;
|
*res = *xs;
|
||||||
for i in 1..len {
|
for i in 1..len {
|
||||||
let x = *xs.add(i);
|
*res = (*res).max(*xs.add(i))
|
||||||
if x > *res {
|
|
||||||
*res = x
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,15 +54,22 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
|
|||||||
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
|
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
|
||||||
*res = *xs;
|
*res = *xs;
|
||||||
for i in 1..len {
|
for i in 1..len {
|
||||||
let x = *xs.add(i);
|
*res = (*res).min(*xs.add(i))
|
||||||
if x < *res {
|
|
||||||
*res = x
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VecOps for f32 {
|
impl VecOps for f32 {
|
||||||
|
#[inline(always)]
|
||||||
|
fn min(self, other: Self) -> Self {
|
||||||
|
Self::min(self, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn max(self, other: Self) -> Self {
|
||||||
|
Self::max(self, other)
|
||||||
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
||||||
super::vec_dot_f32(lhs, rhs, res, len)
|
super::vec_dot_f32(lhs, rhs, res, len)
|
||||||
@ -75,6 +82,16 @@ impl VecOps for f32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl VecOps for half::f16 {
|
impl VecOps for half::f16 {
|
||||||
|
#[inline(always)]
|
||||||
|
fn min(self, other: Self) -> Self {
|
||||||
|
Self::min(self, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn max(self, other: Self) -> Self {
|
||||||
|
Self::max(self, other)
|
||||||
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
|
||||||
let mut res_f32 = 0f32;
|
let mut res_f32 = 0f32;
|
||||||
@ -83,11 +100,61 @@ impl VecOps for half::f16 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VecOps for f64 {}
|
impl VecOps for f64 {
|
||||||
impl VecOps for half::bf16 {}
|
#[inline(always)]
|
||||||
impl VecOps for u8 {}
|
fn min(self, other: Self) -> Self {
|
||||||
impl VecOps for u32 {}
|
Self::min(self, other)
|
||||||
impl VecOps for i64 {}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn max(self, other: Self) -> Self {
|
||||||
|
Self::max(self, other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl VecOps for half::bf16 {
|
||||||
|
#[inline(always)]
|
||||||
|
fn min(self, other: Self) -> Self {
|
||||||
|
Self::min(self, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn max(self, other: Self) -> Self {
|
||||||
|
Self::max(self, other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl VecOps for u8 {
|
||||||
|
#[inline(always)]
|
||||||
|
fn min(self, other: Self) -> Self {
|
||||||
|
<Self as Ord>::min(self, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn max(self, other: Self) -> Self {
|
||||||
|
<Self as Ord>::max(self, other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl VecOps for u32 {
|
||||||
|
#[inline(always)]
|
||||||
|
fn min(self, other: Self) -> Self {
|
||||||
|
<Self as Ord>::min(self, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn max(self, other: Self) -> Self {
|
||||||
|
<Self as Ord>::max(self, other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl VecOps for i64 {
|
||||||
|
#[inline(always)]
|
||||||
|
fn min(self, other: Self) -> Self {
|
||||||
|
<Self as Ord>::min(self, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn max(self, other: Self) -> Self {
|
||||||
|
<Self as Ord>::max(self, other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
|
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
|
||||||
|
@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
.zip(dst.par_chunks_mut(dim_m1))
|
.zip(dst.par_chunks_mut(dim_m1))
|
||||||
.for_each(|(src, dst)| {
|
.for_each(|(src, dst)| {
|
||||||
let mut max = T::neg_infinity();
|
let mut max = T::neg_infinity();
|
||||||
for &s in src.iter() {
|
unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
|
||||||
max = T::max(s, max)
|
|
||||||
}
|
|
||||||
let mut sum_exp = T::zero();
|
|
||||||
for (s, d) in src.iter().zip(dst.iter_mut()) {
|
for (s, d) in src.iter().zip(dst.iter_mut()) {
|
||||||
*d = (*s - max).exp();
|
*d = (*s - max).exp();
|
||||||
sum_exp += *d
|
|
||||||
}
|
}
|
||||||
|
let mut sum_exp = T::zero();
|
||||||
|
unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
|
||||||
for d in dst.iter_mut() {
|
for d in dst.iter_mut() {
|
||||||
*d /= sum_exp
|
*d /= sum_exp
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user