Fast kernels for rotary embeddings. (#1928)

* Fast kernels for rotary embeddings.

* Add a test for the fast CPU kernel.

* Rope cuda bindings.

* Cuda kernel.

* Metal kernel (part 1).

* Cuda kernels.

* Finish the metal kernel.

* Use the new kernels in the quantized example.

* Fix warning.
This commit is contained in:
Laurent Mazare
2024-03-24 22:48:52 +01:00
committed by GitHub
parent cf7d7fcf2f
commit 1b98f84a2b
8 changed files with 375 additions and 26 deletions

View File

@ -3,7 +3,7 @@ use std::collections::HashMap;
use crate::quantized_nn::RmsNorm;
use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096;
@ -154,31 +154,10 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Ten
impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let cos = self
.cos
.narrow(0, index_pos, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
let sin = self
.sin
.narrow(0, index_pos, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
// This mimics the llama.cpp behavior.
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
// The resulting y0 and y1 are also interleaved with:
// y0 = x0*cos - x1*sin
// y1 = x0*sin + x1*cos
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
let rope = rope.flatten_from(D::Minus2)?;
Ok(rope)
let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin)
}
fn forward_attn(