Use softmax-last-dim in whisper. (#810)

Laurent Mazare
2023-09-11 11:05:05 +01:00
committed by GitHub
parent df712ecf64
commit 84ee870efd

@@ -1,5 +1,5 @@
 use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 use serde::Deserialize;
 // The names in comments correspond to the original implementation:
@@ -166,7 +166,7 @@ impl MultiHeadAttention {
         }
         let w = {
             let _enter = self.softmax_span.enter();
-            softmax(&qk, D::Minus1)?
+            candle_nn::ops::softmax_last_dim(&qk)?
         };
         let wv = {
             let _enter = self.matmul_span.enter();