mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Adding tons of profiling and removing the metal allocation (still slow).
This commit is contained in:
@ -79,6 +79,8 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
impl LayerWeights {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cos");
|
||||
let _enter = span.enter();
|
||||
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
|
||||
let cos = self
|
||||
.cos
|
||||
@ -88,21 +90,37 @@ impl LayerWeights {
|
||||
.sin
|
||||
.narrow(0, index_pos, seq_len)?
|
||||
.reshape((seq_len, n_embd / 2, 1))?;
|
||||
drop(_enter);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad");
|
||||
let _enter = span.enter();
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
|
||||
drop(_enter);
|
||||
// This mimics the llama.cpp behavior.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
|
||||
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
|
||||
// The resulting y0 and y1 are also interleaved with:
|
||||
// y0 = x0*cos - x1*sin
|
||||
// y1 = x0*sin + x1*cos
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-reshape");
|
||||
let _enter = span.enter();
|
||||
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
|
||||
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
||||
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
||||
drop(_enter);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad-mul");
|
||||
let _enter = span.enter();
|
||||
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
|
||||
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
|
||||
drop(_enter);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cat");
|
||||
let _enter = span.enter();
|
||||
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
|
||||
drop(_enter);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-rot-flatten");
|
||||
let _enter = span.enter();
|
||||
let rope = rope.flatten_from(D::Minus2)?;
|
||||
drop(_enter);
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user