Use softmax-last-dim in whisper. (#810)

This commit is contained in:
Laurent Mazare
2023-09-11 11:05:05 +01:00
committed by GitHub
parent df712ecf64
commit 84ee870efd

View File

@@ -1,5 +1,5 @@
use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@@ -166,7 +166,7 @@ impl MultiHeadAttention {
}
let w = {
let _enter = self.softmax_span.enter();
-softmax(&qk, D::Minus1)?
+candle_nn::ops::softmax_last_dim(&qk)?
};
let wv = {
let _enter = self.matmul_span.enter();