mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
This commit is contained in:
@ -21,8 +21,6 @@ pub trait BackendStorage: Sized {
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
|
||||
|
||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
|
||||
|
@ -90,7 +90,6 @@ impl Tensor {
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Softmax(node, _)
|
||||
| Op::Unary(node, _)
|
||||
| Op::Elu(node, _)
|
||||
| Op::CustomOp1(node, _) => {
|
||||
@ -324,7 +323,6 @@ impl Tensor {
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
|
||||
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
|
||||
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
|
||||
Op::Reshape(arg) => {
|
||||
let arg_grad = grad.reshape(arg.dims())?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -1236,45 +1236,6 @@ impl Map2 for MatMul {
|
||||
}
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way starting at offset 0.
|
||||
let dims = shape.dims();
|
||||
let elem_per_slice = dims[dim];
|
||||
let prod_pre_dim = dims[..dim].iter().product();
|
||||
let prod_post_dim = dims[dim + 1..].iter().product();
|
||||
if prod_post_dim == 1 {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
let mut sum = 0f64;
|
||||
let idx = pre_idx * elem_per_slice;
|
||||
for v in s[idx..idx + elem_per_slice].iter() {
|
||||
sum += v.to_f64();
|
||||
}
|
||||
let sum = T::from_f64(sum);
|
||||
for v in s[idx..idx + elem_per_slice].iter_mut() {
|
||||
*v /= sum
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
for post_idx in 0..prod_post_dim {
|
||||
let mut sum = 0f64;
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
sum += s[idx].to_f64();
|
||||
idx += prod_post_dim
|
||||
}
|
||||
let sum = T::from_f64(sum);
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
s[idx] /= sum;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
if v.is_sign_positive() {
|
||||
v
|
||||
@ -1513,17 +1474,6 @@ impl BackendStorage for CpuStorage {
|
||||
Cmp(op).map(self, lhs_l, rhs, rhs_l)
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way starting at offset 0.
|
||||
match self {
|
||||
Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::U8(_) | Self::U32(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
Affine(mul, add).map(self, layout)
|
||||
}
|
||||
|
@ -1303,10 +1303,6 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
|
||||
}
|
||||
|
||||
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = U::V.map(&self.slice, &device, layout)?;
|
||||
|
@ -49,10 +49,6 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
@ -93,7 +93,6 @@ pub enum Op {
|
||||
Broadcast(Tensor),
|
||||
Narrow(Tensor, usize, usize, usize),
|
||||
Reshape(Tensor),
|
||||
Softmax(Tensor, usize),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
Elu(Tensor, f64),
|
||||
|
@ -125,15 +125,6 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
// This assumes a contiguous layout and no offset.
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
|
||||
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
|
@ -553,40 +553,6 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
||||
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
|
||||
/// let a = a.softmax(1)?;
|
||||
/// assert_eq!(
|
||||
/// a.to_vec2::<f32>()?,
|
||||
/// &[
|
||||
/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
|
||||
/// [0.004892866, 0.26714143, 0.7261657, 0.0017999847],
|
||||
/// ]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "softmax")?;
|
||||
// TODO: unify the two branches.
|
||||
if self.device().is_cuda() {
|
||||
// We do not have a cuda kernel for divide_by_sum_over_dim so split
|
||||
// the operation.
|
||||
let exp = self.exp()?;
|
||||
let sum_exp = exp.sum_keepdim(dim)?;
|
||||
exp.broadcast_div(&sum_exp)
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
|
||||
// The resulting storage is contiguous.
|
||||
storage.divide_by_sum_over_dim(shape, dim)?;
|
||||
let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
}
|
||||
|
||||
fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
|
||||
match dims {
|
||||
[] => Ok(self),
|
||||
|
@ -1,6 +1,5 @@
|
||||
mod test_utils;
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use test_utils::to_vec3_round;
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
@ -68,42 +67,6 @@ fn transpose(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn softmax(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let t0 = tensor.log()?.softmax(0)?;
|
||||
let t1 = tensor.log()?.softmax(1)?;
|
||||
let t2 = tensor.log()?.softmax(2)?;
|
||||
assert_eq!(
|
||||
to_vec3_round(t0, 4)?,
|
||||
&[
|
||||
// 3/5, 1/2, 4/11
|
||||
[[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
|
||||
// 2/5, 1/2, 7/11
|
||||
[[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
to_vec3_round(t1, 4)?,
|
||||
&[
|
||||
// 3/4, 1/6, 4/13
|
||||
[[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
|
||||
// 2/10, 1/3, 7/15
|
||||
[[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
to_vec3_round(t2, 4)?,
|
||||
&[
|
||||
// (3, 1, 4) / 8, (1, 5, 9) / 15
|
||||
[[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
|
||||
// (2, 1, 7) / 10, (8, 2, 8) / 18
|
||||
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sum(device: &Device) -> Result<()> {
|
||||
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
@ -620,7 +583,6 @@ test_device!(cat, cat_cpu, cat_gpu);
|
||||
test_device!(sum, sum_cpu, sum_gpu);
|
||||
test_device!(transpose, transpose_cpu, transpose_gpu);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
test_device!(softmax, softmax_cpu, softmax_gpu);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu);
|
||||
|
@ -333,7 +333,7 @@ impl BertSelfAttention {
|
||||
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
|
||||
let attention_probs = {
|
||||
let _enter_sm = self.span_softmax.enter();
|
||||
attention_scores.softmax(candle::D::Minus1)?
|
||||
candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
|
||||
};
|
||||
let attention_probs = self.dropout.forward(&attention_probs)?;
|
||||
|
||||
|
@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
|
||||
Ok(mask)
|
||||
}
|
||||
|
||||
// TODO: Use a numerically stable implementation by default.
|
||||
fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
|
||||
let d = d.to_index(xs.shape(), "log-softmax")?;
|
||||
let max = xs.max_keepdim(d)?;
|
||||
let diff = xs.broadcast_sub(&max)?;
|
||||
let num = diff.exp()?;
|
||||
let den = num.sum_keepdim(d)?;
|
||||
num.broadcast_div(&den)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
@ -192,7 +182,7 @@ impl Attention {
|
||||
let mask_value =
|
||||
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
|
||||
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
|
||||
let attn_weights = softmax(&attn_weights, D::Minus1)?;
|
||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
let value = value.contiguous()?;
|
||||
let attn_output = if self.multi_query {
|
||||
attn_weights
|
||||
|
@ -309,10 +309,12 @@ impl FalconAttention {
|
||||
|
||||
// Only handle the case where alibi is None here, and non-flash attention.
|
||||
let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
|
||||
let attention_scores = attention_scores
|
||||
let attention_scores = candle_nn::ops::softmax(
|
||||
&attention_scores
|
||||
.broadcast_add(&mask.squeeze(1)?)?
|
||||
.to_dtype(DType::F32)?
|
||||
.softmax(D::Minus1)?
|
||||
.to_dtype(DType::F32)?,
|
||||
D::Minus1,
|
||||
)?
|
||||
.to_dtype(x.dtype())?;
|
||||
let attn_output = attention_scores
|
||||
.matmul(&value)?
|
||||
|
@ -233,7 +233,7 @@ impl CausalSelfAttention {
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = att.softmax(D::Minus1)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
|
||||
};
|
||||
|
@ -158,7 +158,7 @@ impl CausalSelfAttention {
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = att.softmax(D::Minus1)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
|
@ -323,7 +323,7 @@ impl CausalSelfAttention {
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = att.softmax(D::Minus1)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
|
@ -187,7 +187,7 @@ impl MusicgenAttention {
|
||||
let attn_weights = attn_weights
|
||||
.reshape((b_sz, self.num_heads, tgt_len, src_len))?
|
||||
.broadcast_add(attention_mask)?;
|
||||
let attn_weights = attn_weights.softmax(D::Minus1)?;
|
||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
// TODO: layer_head_mask?
|
||||
let attn_output = attn_weights
|
||||
.matmul(&value_states)?
|
||||
|
@ -223,7 +223,7 @@ impl T5Attention {
|
||||
.transpose(1, 2)?;
|
||||
let scores = q.matmul(&k.t()?)?;
|
||||
// TODO: position_bias_masked
|
||||
let attn_weights = scores.softmax(D::Minus1)?;
|
||||
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let attn_output = self.o.forward(&attn_output)?;
|
||||
Ok(attn_output)
|
||||
|
@ -11,7 +11,7 @@ extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{safetensors::Load, DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
@ -120,9 +120,7 @@ impl Decoder {
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = logits
|
||||
.get(0)?
|
||||
.softmax(0)?
|
||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||
.get(NO_SPEECH_TOKEN as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
@ -132,7 +130,7 @@ impl Decoder {
|
||||
.get(seq_len - 1)?
|
||||
.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = (&logits / t)?.softmax(0)?;
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(&mut self.rng) as u32
|
||||
@ -146,8 +144,7 @@ impl Decoder {
|
||||
.unwrap()
|
||||
};
|
||||
tokens.push(next_token);
|
||||
let prob = logits
|
||||
.softmax(candle::D::Minus1)?
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
|
||||
|
@ -2,7 +2,7 @@
|
||||
// back when using RUST_LIB_BACKTRACE=1.
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
|
||||
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
// The names in comments correspond to the original implementation:
|
||||
@ -154,7 +154,7 @@ impl MultiHeadAttention {
|
||||
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
|
||||
qk = qk.broadcast_add(&mask)?
|
||||
}
|
||||
let w = qk.softmax(candle::D::Minus1)?;
|
||||
let w = softmax(&qk, candle::D::Minus1)?;
|
||||
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
|
||||
Ok(wv)
|
||||
}
|
||||
|
@ -21,3 +21,4 @@ rayon = "1.7.0"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
||||
|
@ -21,7 +21,7 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
|
||||
let k = k.to_dtype(DType::F32)?;
|
||||
let v = v.to_dtype(DType::F32)?;
|
||||
let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
|
||||
let att = att.softmax(D::Minus1)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
|
||||
Ok(output)
|
||||
|
@ -1,5 +1,29 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
||||
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
|
||||
/// let a = candle_nn::ops::softmax(&a, 1)?;
|
||||
/// assert_eq!(
|
||||
/// a.to_vec2::<f32>()?,
|
||||
/// &[
|
||||
/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
|
||||
/// [0.0048928666, 0.26714146, 0.7261658, 0.0017999851]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
|
||||
let dim = dim.to_index(xs.shape(), "softmax")?;
|
||||
let max = xs.max_keepdim(dim)?;
|
||||
let diff = xs.broadcast_sub(&max)?;
|
||||
let num = diff.exp()?;
|
||||
let den = num.sum_keepdim(dim)?;
|
||||
num.broadcast_div(&den)
|
||||
}
|
||||
|
||||
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
|
||||
let d = d.to_index(xs.shape(), "log-softmax")?;
|
||||
let max = xs.max_keepdim(d)?;
|
||||
|
62
candle-nn/tests/ops.rs
Normal file
62
candle-nn/tests/ops.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec3::<f32>()?;
|
||||
let t = t
|
||||
.iter()
|
||||
.map(|t| {
|
||||
t.iter()
|
||||
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softmax() -> Result<()> {
|
||||
let device = &Device::Cpu;
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
|
||||
let t1 = candle_nn::ops::softmax(&tensor.log()?, 1)?;
|
||||
let t2 = candle_nn::ops::softmax(&tensor.log()?, 2)?;
|
||||
assert_eq!(
|
||||
to_vec3_round(t0, 4)?,
|
||||
&[
|
||||
// 3/5, 1/2, 4/11
|
||||
[[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
|
||||
// 2/5, 1/2, 7/11
|
||||
[[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
to_vec3_round(t1, 4)?,
|
||||
&[
|
||||
// 3/4, 1/6, 4/13
|
||||
[[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
|
||||
// 2/10, 1/3, 7/15
|
||||
[[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
to_vec3_round(t2, 4)?,
|
||||
&[
|
||||
// (3, 1, 4) / 8, (1, 5, 9) / 15
|
||||
[[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
|
||||
// (2, 1, 7) / 10, (8, 2, 8) / 18
|
||||
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softmax_numerical_stability() -> Result<()> {
|
||||
let dev = &Device::Cpu;
|
||||
let xs = Tensor::new(&[1234f32, 0.], dev)?;
|
||||
let softmax = candle_nn::ops::softmax(&xs, 0)?;
|
||||
assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
|
||||
Ok(())
|
||||
}
|
@ -17,7 +17,7 @@ impl LogitsProcessor {
|
||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let next_token = if let Some(temperature) = self.temperature {
|
||||
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
||||
let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
|
||||
let prs: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||
distr.sample(&mut self.rng) as u32
|
||||
|
@ -158,7 +158,7 @@ impl CausalSelfAttention {
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = att.softmax(D::Minus1)?;
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::model::{Cache, Config, Llama};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
@ -88,7 +88,7 @@ impl LogitsProcessor {
|
||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let next_token = if let Some(temperature) = self.temperature {
|
||||
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
||||
let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
|
||||
let prs: Vec<f32> = prs.to_vec1()?;
|
||||
let distr =
|
||||
rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
|
||||
|
@ -200,7 +200,7 @@ impl MultiHeadAttention {
|
||||
}
|
||||
let w = {
|
||||
let _timer = crate::Timer::new("qk::softmax");
|
||||
qk.softmax(candle::D::Minus1)?
|
||||
candle_nn::ops::softmax(&qk, candle::D::Minus1)?
|
||||
};
|
||||
let wv = {
|
||||
let _timer = crate::Timer::new("wv::matmul");
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::model::{Config, Whisper};
|
||||
use anyhow::Error as E;
|
||||
use candle::{safetensors::Load, DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokenizers::Tokenizer;
|
||||
@ -127,9 +127,7 @@ impl Decoder {
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = logits
|
||||
.get(0)?
|
||||
.softmax(0)?
|
||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||
.get(NO_SPEECH_TOKEN as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
@ -139,7 +137,7 @@ impl Decoder {
|
||||
.get(seq_len - 1)?
|
||||
.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = (&logits / t)?.softmax(0)?;
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(rng) as u32
|
||||
@ -153,8 +151,7 @@ impl Decoder {
|
||||
.unwrap()
|
||||
};
|
||||
tokens.push(next_token);
|
||||
let prob = logits
|
||||
.softmax(candle::D::Minus1)?
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
|
||||
|
Reference in New Issue
Block a user