mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Use softmax-last-dim in whisper. (#810)
This commit is contained in:
@ -1,5 +1,5 @@
|
|||||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
// The names in comments correspond to the original implementation:
|
// The names in comments correspond to the original implementation:
|
||||||
@ -166,7 +166,7 @@ impl MultiHeadAttention {
|
|||||||
}
|
}
|
||||||
let w = {
|
let w = {
|
||||||
let _enter = self.softmax_span.enter();
|
let _enter = self.softmax_span.enter();
|
||||||
softmax(&qk, D::Minus1)?
|
candle_nn::ops::softmax_last_dim(&qk)?
|
||||||
};
|
};
|
||||||
let wv = {
|
let wv = {
|
||||||
let _enter = self.matmul_span.enter();
|
let _enter = self.matmul_span.enter();
|
||||||
|
Reference in New Issue
Block a user