mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the rope THD kernel. (#2014)
* Add the rope THD kernel.
* CUDA kernel for rope-thd.
* Add the Metal kernels.
* Add a dedicated test.
This commit is contained in:
@@ -177,30 +177,14 @@ impl RotaryEmbedding {
|
||||
}
|
||||
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
|
||||
let rotary_dim = rotary_dim * 2;
|
||||
let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
|
||||
let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?.contiguous()?;
|
||||
let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
|
||||
let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
|
||||
let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?.contiguous()?;
|
||||
let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
|
||||
let q12 = q_rot.chunk(2, D::Minus1)?;
|
||||
let k12 = k_rot.chunk(2, D::Minus1)?;
|
||||
let (q1, q2) = (&q12[0], &q12[1]);
|
||||
let (k1, k2) = (&k12[0], &k12[1]);
|
||||
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
||||
let q_rot = Tensor::cat(
|
||||
&[
|
||||
(q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
|
||||
(q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
|
||||
],
|
||||
D::Minus1,
|
||||
)?;
|
||||
let k_rot = Tensor::cat(
|
||||
&[
|
||||
(k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
|
||||
(k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
|
||||
],
|
||||
D::Minus1,
|
||||
)?;
|
||||
let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
|
||||
let q_rot = candle_nn::rotary_emb::rope_thd(&q_rot, &c, &s)?;
|
||||
let k_rot = candle_nn::rotary_emb::rope_thd(&k_rot, &c, &s)?;
|
||||
let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
|
||||
let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
|
||||
let v = qkv.i((.., .., 2))?;
|
||||
|
Reference in New Issue
Block a user