Add the rope THD kernel. (#2014)

* Add the rope THD kernel.

* CUDA kernel for rope-thd.

* Add the Metal kernels.

* Add a dedicated test.
This commit is contained in:
Laurent Mazare
2024-04-05 08:32:58 +02:00
committed by GitHub
parent ace282e5c2
commit 2ac302a5d1
6 changed files with 400 additions and 31 deletions

View File

@@ -177,30 +177,14 @@ impl RotaryEmbedding {
}
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
let rotary_dim = rotary_dim * 2;
let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?.contiguous()?;
let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?.contiguous()?;
let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
let q12 = q_rot.chunk(2, D::Minus1)?;
let k12 = k_rot.chunk(2, D::Minus1)?;
let (q1, q2) = (&q12[0], &q12[1]);
let (k1, k2) = (&k12[0], &k12[1]);
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
let q_rot = Tensor::cat(
&[
(q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
(q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
],
D::Minus1,
)?;
let k_rot = Tensor::cat(
&[
(k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
(k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
],
D::Minus1,
)?;
let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
let q_rot = candle_nn::rotary_emb::rope_thd(&q_rot, &c, &s)?;
let k_rot = candle_nn::rotary_emb::rope_thd(&k_rot, &c, &s)?;
let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
let v = qkv.i((.., .., 2))?;