Mirror of https://github.com/huggingface/candle.git
More accelerate optimizations (#427)
* Add more tracing to the whisper example.
* Support accelerate in more examples.
* Use accelerate for pointwise functions.
* Use accelerate for binary operations too.
* Bugfix for binary operation: use the rhs before the lhs.
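The last bullet most likely refers to a quirk of Apple's vDSP API: several binary kernels, vDSP_vsub among them, take their operands in reverse order (Apple documents the result as C[n] = B[n] - A[n]), so the right-hand side must be passed first for lhs - rhs to come out right. A minimal sketch of the corrected call order, with a hypothetical wrapper vs_sub rather than candle's actual binding:

// Sketch only: a hand-rolled binding to Accelerate's vDSP_vsub, assuming
// 64-bit macOS (vDSP_Stride is a signed long, vDSP_Length an unsigned long).
extern "C" {
    fn vDSP_vsub(
        a: *const f32, // subtrahend ("A" in Apple's docs)
        ia: isize,
        b: *const f32, // minuend ("B")
        ib: isize,
        c: *mut f32,   // result: c[n] = b[n] - a[n]
        ic: isize,
        n: usize,
    );
}

/// Elementwise `lhs - rhs`: note the rhs is passed *before* the lhs.
fn vs_sub(lhs: &[f32], rhs: &[f32], out: &mut [f32]) {
    assert_eq!(lhs.len(), rhs.len());
    assert_eq!(lhs.len(), out.len());
    unsafe { vDSP_vsub(rhs.as_ptr(), 1, lhs.as_ptr(), 1, out.as_mut_ptr(), 1, out.len()) }
}

Commutative operations such as add and mul mask this kind of bug; it only surfaces for sub and div, which is presumably why it slipped through until now.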
@@ -1,5 +1,8 @@
 // TODO: Add an offline mode.
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
@@ -1,9 +1,11 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
 // - kv-cache support?
 // - Batch size greater than 1.
 // - More token filters (SuppressBlanks, ApplyTimestampRules).
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
 
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
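Both hunks above only pull in accelerate_src so the Accelerate framework actually gets linked when the feature is enabled; the kernels themselves are reached through feature-gated wrappers. A hedged sketch of what "use accelerate for pointwise functions" can look like (vvexpf is vForce's elementwise exp; the wrapper name exp_slice is made up):

#[cfg(feature = "accelerate")]
extern "C" {
    // vForce: y[i] = exp(x[i]) for *n elements.
    fn vvexpf(y: *mut f32, x: *const f32, n: *const i32);
}

/// Elementwise exp with an Accelerate fast path and a scalar fallback.
fn exp_slice(xs: &[f32], ys: &mut [f32]) {
    assert_eq!(xs.len(), ys.len());
    #[cfg(feature = "accelerate")]
    unsafe {
        let n = xs.len() as i32;
        vvexpf(ys.as_mut_ptr(), xs.as_ptr(), &n);
    }
    #[cfg(not(feature = "accelerate"))]
    for (x, y) in xs.iter().zip(ys.iter_mut()) {
        *y = x.exp();
    }
}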
@@ -1,4 +1,4 @@
-use candle::{Device, Result, Tensor};
+use candle::{Device, IndexOp, Result, Tensor};
 use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
 use serde::Deserialize;
 
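The new IndexOp import provides the .i() indexing used in the attention hunk below; on a 2-D tensor a ranged .i() is equivalent to a pair of narrow calls. A small illustrative sketch (the function name crop_mask is ours, not the diff's):

use candle::{IndexOp, Result, Tensor};

fn crop_mask(mask: &Tensor, n_ctx: usize) -> Result<Tensor> {
    // Before this commit: one narrow per dimension.
    let _old = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
    // After: a single ranged index over both dimensions.
    mask.i((0..n_ctx, 0..n_ctx))
}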
@@ -105,12 +105,16 @@ struct MultiHeadAttention {
     out: Linear,
     n_head: usize,
     span: tracing::Span,
+    softmax_span: tracing::Span,
+    matmul_span: tracing::Span,
     kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
+        let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
+        let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
         let query = linear(n_state, n_state, vb.pp("q_proj"))?;
         let value = linear(n_state, n_state, vb.pp("v_proj"))?;
         let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@@ -122,6 +126,8 @@ impl MultiHeadAttention {
             out,
             n_head,
             span,
+            softmax_span,
+            matmul_span,
             kv_cache: None,
         })
     }
@@ -178,13 +184,24 @@ impl MultiHeadAttention {
         let q = (self.reshape_head(q)? * scale)?;
         let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
         let v = self.reshape_head(v)?.contiguous()?;
-        let mut qk = q.matmul(&k)?;
+        let mut qk = {
+            let _enter = self.matmul_span.enter();
+            q.matmul(&k)?
+        };
         if let Some(mask) = mask {
-            let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
+            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = softmax(&qk, candle::D::Minus1)?;
-        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
+        let w = {
+            let _enter = self.softmax_span.enter();
+            softmax(&qk, candle::D::Minus1)?
+        };
+        let wv = {
+            let _enter = self.matmul_span.enter();
+            w.matmul(&v)?
+        }
+        .transpose(1, 2)?
+        .flatten_from(2)?;
         Ok(wv)
     }
 }
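The { let _enter = span.enter(); ... } blocks keep each span alive for exactly the duration of the enclosed call, so the two matmuls and the softmax are recorded as separate regions. A hedged sketch of wiring up a collector, along the lines of what the whisper example's --tracing flag does with the tracing-chrome and tracing-subscriber crates:

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The guard must stay alive until exit so the trace file is flushed.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    // ... run the model; each `multi-head-attn-*` span from the hunks above
    // shows up as a slice in the emitted trace-*.json (viewable in
    // chrome://tracing or Perfetto), separating matmul time from softmax time.
}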