Softmax numerical stability. (#267)

* Softmax numerical stability.

* Fix the flash-attn test.
Laurent Mazare
2023-07-28 13:13:01 +01:00
committed by GitHub
parent 68eab38de6
commit 3eb2bc6d07
28 changed files with 117 additions and 188 deletions
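For reference, the stability fix rolled out across these files is the standard shift-invariance of softmax: subtracting the per-slice maximum before exponentiating leaves the result unchanged while keeping `exp` from overflowing:

$$\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_k x_k.$$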

View File

@@ -21,8 +21,6 @@ pub trait BackendStorage: Sized {
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;

View File

@@ -90,7 +90,6 @@ impl Tensor {
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::CustomOp1(node, _) => {
@@ -324,7 +323,6 @@ impl Tensor {
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
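Aside: with `Op::Softmax` removed from the op enum, softmax gradients now fall out of the differentiable `exp`/`broadcast_sub`/`sum`/`div` primitives it is built from, so no dedicated backward rule is needed. The Jacobian that composition realizes is the standard

$$\frac{\partial s_i}{\partial x_j} = s_i\,(\delta_{ij} - s_j), \qquad s = \operatorname{softmax}(x).$$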

View File

@@ -1236,45 +1236,6 @@ impl Map2 for MatMul {
}
}
fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way starting at offset 0.
let dims = shape.dims();
let elem_per_slice = dims[dim];
let prod_pre_dim = dims[..dim].iter().product();
let prod_post_dim = dims[dim + 1..].iter().product();
if prod_post_dim == 1 {
for pre_idx in 0..prod_pre_dim {
let mut sum = 0f64;
let idx = pre_idx * elem_per_slice;
for v in s[idx..idx + elem_per_slice].iter() {
sum += v.to_f64();
}
let sum = T::from_f64(sum);
for v in s[idx..idx + elem_per_slice].iter_mut() {
*v /= sum
}
}
} else {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += s[idx].to_f64();
idx += prod_post_dim
}
let sum = T::from_f64(sum);
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
s[idx] /= sum;
idx += prod_post_dim
}
}
}
}
Ok(())
}
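For readers untangling the stride arithmetic in the kernel removed above, a worked example with a made-up shape:

```rust
// Hypothetical shape (2, 3, 4), dim = 1, row-major contiguous data:
// elem_per_slice = 3, prod_pre_dim = 2, prod_post_dim = 4.
// For pre_idx = 1 and post_idx = 2 the walk starts at
//     idx = 1 * 4 * 3 + 2 = 14
// and steps by prod_post_dim = 4, visiting flat indices 14, 18, 22,
// i.e. elements (1, 0, 2), (1, 1, 2), (1, 2, 2) of the tensor.
```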
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
if v.is_sign_positive() {
v
@@ -1513,17 +1474,6 @@ impl BackendStorage for CpuStorage {
Cmp(op).map(self, lhs_l, rhs, rhs_l)
}
fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way starting at offset 0.
match self {
Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
Self::U8(_) | Self::U32(_) => Ok(()),
}
}
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
Affine(mul, add).map(self, layout)
}

View File

@@ -1303,10 +1303,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
}
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = U::V.map(&self.slice, &device, layout)?;

View File

@@ -49,10 +49,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@@ -93,7 +93,6 @@ pub enum Op {
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
Reshape(Tensor),
Softmax(Tensor, usize),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Elu(Tensor, f64),

View File

@@ -125,15 +125,6 @@ impl Storage {
}
}
// This assumes a contiguous layout and no offset.
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
}
Ok(())
}
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
match self {
Storage::Cpu(storage) => {

View File

@@ -553,40 +553,6 @@ impl Tensor {
}
}
/// Applies the softmax function to the input tensor, rescaling the elements so that all elements
/// on a slice of fixed index on dimension `dim` lie between 0 and 1 and sum to 1.
///
/// ```rust
/// use candle::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = a.softmax(1)?;
/// assert_eq!(
/// a.to_vec2::<f32>()?,
/// &[
/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
/// [0.004892866, 0.26714143, 0.7261657, 0.0017999847],
/// ]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "softmax")?;
// TODO: unify the two branches.
if self.device().is_cuda() {
// We do not have a cuda kernel for divide_by_sum_over_dim so split
// the operation.
let exp = self.exp()?;
let sum_exp = exp.sum_keepdim(dim)?;
exp.broadcast_div(&sum_exp)
} else {
let shape = self.shape();
let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
// The resulting storage is contiguous.
storage.divide_by_sum_over_dim(shape, dim)?;
let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
Ok(from_storage(storage, shape.clone(), op, false))
}
}
fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
match dims {
[] => Ok(self),

View File

@@ -1,6 +1,5 @@
mod test_utils;
use candle::{DType, Device, IndexOp, Result, Tensor};
use test_utils::to_vec3_round;
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -68,42 +67,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}
fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let t0 = tensor.log()?.softmax(0)?;
let t1 = tensor.log()?.softmax(1)?;
let t2 = tensor.log()?.softmax(2)?;
assert_eq!(
to_vec3_round(t0, 4)?,
&[
// 3/5, 1/2, 4/11
[[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
// 2/5, 1/2, 7/11
[[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
]
);
assert_eq!(
to_vec3_round(t1, 4)?,
&[
// 3/4, 1/6, 4/13
[[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
// 2/10, 1/3, 7/15
[[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
]
);
assert_eq!(
to_vec3_round(t2, 4)?,
&[
// (3, 1, 4) / 8, (1, 5, 9) / 15
[[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
// (2, 1, 7) / 10, (8, 2, 8) / 18
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
]
);
Ok(())
}
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@@ -620,7 +583,6 @@ test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(softmax, softmax_cpu, softmax_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);

View File

@@ -333,7 +333,7 @@ impl BertSelfAttention {
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
- attention_scores.softmax(candle::D::Minus1)?
+ candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
};
let attention_probs = self.dropout.forward(&attention_probs)?;

View File

@@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
Ok(mask)
}
// TODO: Use a numerically stable implementation by default.
fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
let d = d.to_index(xs.shape(), "log-softmax")?;
let max = xs.max_keepdim(d)?;
let diff = xs.broadcast_sub(&max)?;
let num = diff.exp()?;
let den = num.sum_keepdim(d)?;
num.broadcast_div(&den)
}
#[derive(Debug)]
pub struct Config {
pub vocab_size: usize,
@@ -192,7 +182,7 @@ impl Attention {
let mask_value =
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
- let attn_weights = softmax(&attn_weights, D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let value = value.contiguous()?;
let attn_output = if self.multi_query {
attn_weights

View File

@@ -309,10 +309,12 @@ impl FalconAttention {
// Only handle the case where alibi is None here, and non-flash attention.
let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
- let attention_scores = attention_scores
+ let attention_scores = candle_nn::ops::softmax(
+     &attention_scores
.broadcast_add(&mask.squeeze(1)?)?
- .to_dtype(DType::F32)?
- .softmax(D::Minus1)?
+     .to_dtype(DType::F32)?,
+     D::Minus1,
+ )?
.to_dtype(x.dtype())?;
let attn_output = attention_scores
.matmul(&value)?
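Note that the hunk keeps the usual mixed-precision pattern: upcast to `F32` for the softmax reduction, then cast back, since `exp` and the subsequent sum saturate quickly in f16/bf16. A minimal sketch of the pattern (the helper name and free-standing form are assumptions, not Falcon's actual code):

```rust
use candle::{DType, Result, Tensor, D};

// Sketch: run the softmax reduction in f32 even when the surrounding model
// computes in f16/bf16, then cast the probabilities back to the input dtype.
fn softmax_in_f32(scores: &Tensor) -> Result<Tensor> {
    let dtype = scores.dtype();
    let probs = candle_nn::ops::softmax(&scores.to_dtype(DType::F32)?, D::Minus1)?;
    probs.to_dtype(dtype)
}
```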

View File

@@ -233,7 +233,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};

View File

@@ -158,7 +158,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;

View File

@@ -323,7 +323,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;

View File

@@ -187,7 +187,7 @@ impl MusicgenAttention {
let attn_weights = attn_weights
.reshape((b_sz, self.num_heads, tgt_len, src_len))?
.broadcast_add(attention_mask)?;
- let attn_weights = attn_weights.softmax(D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
// TODO: layer_head_mask?
let attn_output = attn_weights
.matmul(&value_states)?

View File

@@ -223,7 +223,7 @@ impl T5Attention {
.transpose(1, 2)?;
let scores = q.matmul(&k.t()?)?;
// TODO: position_bias_masked
- let attn_weights = scores.softmax(D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)

View File

@@ -11,7 +11,7 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
- use candle_nn::VarBuilder;
+ use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
@@ -120,9 +120,7 @@ impl Decoder {
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
- no_speech_prob = logits
-     .get(0)?
-     .softmax(0)?
+ no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
.to_scalar::<f32>()? as f64;
}
@@ -132,7 +130,7 @@ impl Decoder {
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
- let prs = (&logits / t)?.softmax(0)?;
+ let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
@@ -146,8 +144,7 @@
.unwrap()
};
tokens.push(next_token);
- let prob = logits
-     .softmax(candle::D::Minus1)?
+ let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {

View File

@@ -2,7 +2,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
- use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+ use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@@ -154,7 +154,7 @@ impl MultiHeadAttention {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
- let w = qk.softmax(candle::D::Minus1)?;
+ let w = softmax(&qk, candle::D::Minus1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
Ok(wv)
}

View File

@@ -21,3 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", features = ["cuda"] }

View File

@@ -21,7 +21,7 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
Ok(output)

View File

@@ -1,5 +1,29 @@
use candle::{Result, Tensor};
/// Applies the softmax function to the input tensor, rescaling the elements so that all elements
/// on a slice of fixed index on dimension `dim` lie between 0 and 1 and sum to 1.
///
/// ```rust
/// use candle::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = candle_nn::ops::softmax(&a, 1)?;
/// assert_eq!(
/// a.to_vec2::<f32>()?,
/// &[
/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
/// [0.0048928666, 0.26714146, 0.7261658, 0.0017999851]
/// ]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
let dim = dim.to_index(xs.shape(), "softmax")?;
let max = xs.max_keepdim(dim)?;
let diff = xs.broadcast_sub(&max)?;
let num = diff.exp()?;
let den = num.sum_keepdim(dim)?;
num.broadcast_div(&den)
}
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
let d = d.to_index(xs.shape(), "log-softmax")?;
let max = xs.max_keepdim(d)?;
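To see why the `max_keepdim` subtraction above matters, here is a self-contained sketch on plain `f32` slices (an illustration, not part of the crate) contrasting the naive formula with the shifted one:

```rust
fn naive_softmax(xs: &[f32]) -> Vec<f32> {
    // Exponentiate first, normalize after: overflows once any input exceeds ~88.
    let denom: f32 = xs.iter().map(|x| x.exp()).sum();
    xs.iter().map(|x| x.exp() / denom).collect()
}

fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    // Shift by the max so the largest exponent is exp(0) = 1.
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let denom: f32 = xs.iter().map(|x| (x - max).exp()).sum();
    xs.iter().map(|x| (x - max).exp() / denom).collect()
}

fn main() {
    println!("{:?}", naive_softmax(&[1234f32, 0.])); // [NaN, 0.0]: inf / inf
    println!("{:?}", stable_softmax(&[1234f32, 0.])); // [1.0, 0.0]
}
```

The `softmax_numerical_stability` test added below pushes this same `[1234, 0]` case through the tensor version.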

candle-nn/tests/ops.rs (new file, 62 lines)
View File

@@ -0,0 +1,62 @@
use candle::{Device, Result, Tensor};
// Rounds every element of a rank-3 tensor to `digits` decimal places so the
// float comparisons in the tests below are stable.
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
let b = 10f32.powi(digits);
let t = t.to_vec3::<f32>()?;
let t = t
.iter()
.map(|t| {
t.iter()
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
.collect()
})
.collect();
Ok(t)
}
#[test]
fn softmax() -> Result<()> {
let device = &Device::Cpu;
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
let t1 = candle_nn::ops::softmax(&tensor.log()?, 1)?;
let t2 = candle_nn::ops::softmax(&tensor.log()?, 2)?;
assert_eq!(
to_vec3_round(t0, 4)?,
&[
// 3/5, 1/2, 4/11
[[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
// 2/5, 1/2, 7/11
[[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
]
);
assert_eq!(
to_vec3_round(t1, 4)?,
&[
// 3/4, 1/6, 4/13
[[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
// 2/10, 1/3, 7/15
[[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
]
);
assert_eq!(
to_vec3_round(t2, 4)?,
&[
// (3, 1, 4) / 8, (1, 5, 9) / 15
[[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
// (2, 1, 7) / 10, (8, 2, 8) / 18
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
]
);
Ok(())
}
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
let xs = Tensor::new(&[1234f32, 0.], dev)?;
let softmax = candle_nn::ops::softmax(&xs, 0)?;
assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
Ok(())
}

View File

@@ -17,7 +17,7 @@ impl LogitsProcessor {
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
- let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+ let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
distr.sample(&mut self.rng) as u32
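For context, the pattern being rewired here is plain temperature sampling. A standalone sketch (hypothetical helper on plain slices, rand 0.8-style API, not the crate's code):

```rust
use rand::distributions::{Distribution, WeightedIndex};
use rand::rngs::StdRng;

// Scale logits by 1/temperature, softmax (shifting by the max for stability),
// then draw an index; WeightedIndex accepts unnormalized weights.
fn sample_with_temperature(logits: &[f32], temperature: f32, rng: &mut StdRng) -> usize {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = logits
        .iter()
        .map(|l| ((l - max) / temperature).exp())
        .collect();
    WeightedIndex::new(&weights).unwrap().sample(rng)
}
```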

View File

@@ -158,7 +158,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;

View File

@@ -1,7 +1,7 @@
use crate::model::{Cache, Config, Llama};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
- use candle_nn::VarBuilder;
+ use candle_nn::{ops::softmax, VarBuilder};
use rand::{distributions::Distribution, SeedableRng};
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
@@ -88,7 +88,7 @@ impl LogitsProcessor {
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
- let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+ let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr =
rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;

View File

@@ -200,7 +200,7 @@ impl MultiHeadAttention {
}
let w = {
let _timer = crate::Timer::new("qk::softmax");
- qk.softmax(candle::D::Minus1)?
+ candle_nn::ops::softmax(&qk, candle::D::Minus1)?
};
let wv = {
let _timer = crate::Timer::new("wv::matmul");

View File

@@ -1,7 +1,7 @@
use crate::model::{Config, Whisper};
use anyhow::Error as E;
use candle::{safetensors::Load, DType, Device, Tensor};
- use candle_nn::VarBuilder;
+ use candle_nn::{ops::softmax, VarBuilder};
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
@@ -127,9 +127,7 @@ impl Decoder {
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
- no_speech_prob = logits
-     .get(0)?
-     .softmax(0)?
+ no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
.to_scalar::<f32>()? as f64;
}
@@ -139,7 +137,7 @@ impl Decoder {
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
- let prs = (&logits / t)?.softmax(0)?;
+ let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(rng) as u32
@@ -153,8 +151,7 @@
.unwrap()
};
tokens.push(next_token);
- let prob = logits
-     .softmax(candle::D::Minus1)?
+ let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {