More accelerate optimizations (#427)

* Add more tracing to the whisper example.

* Support accelerate in more examples.

* Use accelerate for pointwise functions.

* Use accelerate for binary operations too.

* Bugfix for binary operation: use the rhs before the lhs.
Author: Laurent Mazare
Date: 2023-08-13 13:53:34 +02:00 (committed by GitHub)
Parent: 60cd1551ca
Commit: 9aca398a4f
9 changed files with 320 additions and 11 deletions
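The pointwise and binary-operation bullets above, and in particular the rhs-before-lhs bugfix, touch on a classic pitfall when forwarding to Accelerate's vDSP routines: the argument order in the C signature is not the mathematical order. vDSP_vsub, for example, computes C = A - B but takes the B pointer first. The sketch below is hypothetical and is not candle's kernel code; the sub_f32 helper and the FFI declaration are only there to illustrate the feature-gated pattern for an element-wise subtraction, with a plain scalar fallback when the accelerate feature is off.

// Hypothetical sketch, not candle's implementation: a feature-gated
// element-wise subtraction. vDSP_vsub computes C = A - B but takes the B
// pointer first, so the rhs has to be passed before the lhs.
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "accelerate")]
extern "C" {
    // void vDSP_vsub(const float *B, long IB, const float *A, long IA,
    //                float *C, long IC, unsigned long N);
    // Stride/length types written as i64/u64, assuming a 64-bit macOS target.
    #[allow(non_snake_case)]
    fn vDSP_vsub(b: *const f32, ib: i64, a: *const f32, ia: i64, c: *mut f32, ic: i64, n: u64);
}

fn sub_f32(lhs: &[f32], rhs: &[f32], out: &mut [f32]) {
    assert_eq!(lhs.len(), rhs.len());
    assert_eq!(lhs.len(), out.len());
    #[cfg(feature = "accelerate")]
    unsafe {
        // Note the order: rhs before lhs, because B is the subtrahend.
        vDSP_vsub(
            rhs.as_ptr(), 1,
            lhs.as_ptr(), 1,
            out.as_mut_ptr(), 1,
            lhs.len() as u64,
        );
    }
    #[cfg(not(feature = "accelerate"))]
    for ((o, &l), &r) in out.iter_mut().zip(lhs).zip(rhs) {
        *o = l - r;
    }
}

fn main() {
    let lhs = [3.0f32, 4.0, 6.0];
    let rhs = [1.0f32, 1.5, 2.0];
    let mut out = [0.0f32; 3];
    sub_f32(&lhs, &rhs, &mut out);
    println!("{out:?}"); // [2.0, 2.5, 4.0]
}

For candle itself the switch is just a cargo feature; on macOS this is typically something along the lines of: cargo run --example whisper --release --features accelerate.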

@@ -1,9 +1,11 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
 // - kv-cache support?
 // - Batch size greater than 1.
 // - More token filters (SuppressBlanks, ApplyTimestampRules).
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

@@ -1,4 +1,4 @@
-use candle::{Device, Result, Tensor};
+use candle::{Device, IndexOp, Result, Tensor};
 use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
 use serde::Deserialize;
@@ -105,12 +105,16 @@ struct MultiHeadAttention {
     out: Linear,
     n_head: usize,
     span: tracing::Span,
+    softmax_span: tracing::Span,
+    matmul_span: tracing::Span,
     kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
+        let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
+        let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
         let query = linear(n_state, n_state, vb.pp("q_proj"))?;
         let value = linear(n_state, n_state, vb.pp("v_proj"))?;
         let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@@ -122,6 +126,8 @@ impl MultiHeadAttention {
             out,
             n_head,
             span,
+            softmax_span,
+            matmul_span,
             kv_cache: None,
         })
     }
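The new softmax_span and matmul_span fields follow the standard tracing pattern: build the span once in load() and enter it around each hot call, so the time spent there is attributed to the span by whichever subscriber is installed (candle's examples typically write out a Chrome trace via the tracing-chrome crate). Below is a minimal, self-contained sketch of the same pattern, using a hypothetical Worker type and the plain fmt subscriber instead.

// Minimal sketch of the span pattern added above (Worker is hypothetical):
// the span is created once, then entered around each hot call so the busy
// time is attributed to "worker-matmul" by the installed subscriber.
use tracing_subscriber::fmt::format::FmtSpan;

struct Worker {
    matmul_span: tracing::Span,
}

impl Worker {
    fn new() -> Self {
        let matmul_span = tracing::span!(tracing::Level::TRACE, "worker-matmul");
        Self { matmul_span }
    }

    fn step(&self) {
        // Everything until the guard is dropped counts as time inside the span.
        let _enter = self.matmul_span.enter();
        std::thread::sleep(std::time::Duration::from_millis(1));
    }
}

fn main() {
    // Log a close event, with accumulated busy/idle time, when each span closes.
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::TRACE)
        .with_span_events(FmtSpan::CLOSE)
        .init();
    let worker = Worker::new();
    for _ in 0..3 {
        worker.step();
    }
}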
@@ -178,13 +184,24 @@ impl MultiHeadAttention {
         let q = (self.reshape_head(q)? * scale)?;
         let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
         let v = self.reshape_head(v)?.contiguous()?;
-        let mut qk = q.matmul(&k)?;
+        let mut qk = {
+            let _enter = self.matmul_span.enter();
+            q.matmul(&k)?
+        };
         if let Some(mask) = mask {
-            let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
+            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = softmax(&qk, candle::D::Minus1)?;
-        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
+        let w = {
+            let _enter = self.softmax_span.enter();
+            softmax(&qk, candle::D::Minus1)?
+        };
+        let wv = {
+            let _enter = self.matmul_span.enter();
+            w.matmul(&v)?
+        }
+        .transpose(1, 2)?
+        .flatten_from(2)?;
         Ok(wv)
     }
 }
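The mask change in this hunk swaps two chained narrow() calls for a single IndexOp::i() call with a tuple of ranges; both produce the same sub-tensor. A small sketch of that equivalence, using the same candle import path as the file above (the crate itself is published as candle-core):

// Sketch of the narrow() vs IndexOp::i() equivalence shown in the diff.
use candle::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A 4x4 "mask" with distinct values so the two slicing forms can be compared.
    let mask = Tensor::arange(0f32, 16f32, &dev)?.reshape((4, 4))?;
    let n_ctx = 3;

    // Old form: narrow each dimension separately.
    let a = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
    // New form: one multi-dimensional index covering both dimensions.
    let b = mask.i((0..n_ctx, 0..n_ctx))?;

    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
    println!("sliced mask: {:?}", b.to_vec2::<f32>()?);
    Ok(())
}

The i((0..n, 0..n)) form reads closer to the Python slicing in the original whisper code, which is presumably why the example moved to it.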