From 196765e995f7f4bd3b9610a22f8ef5b009437a4e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 25 Mar 2024 23:26:05 +0100 Subject: [PATCH] Use the new rope kernel in mistral. (#1937) * Use the new rope kernel in mistral. * Compute the cos and sin with full precision. * Bugfix. --- candle-kernels/src/reduce.cu | 4 ++-- candle-transformers/src/models/mistral.rs | 20 ++++++------------- .../src/models/quantized_mistral.rs | 20 ++++++------------- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 48bbcd83..2af81c42 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -150,7 +150,7 @@ __device__ void softmax(const T * x, T * dst, const int ncols) { template __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (2 * idx > bh * td) return; + if (2 * idx >= bh * td) return; uint32_t rope_idx = idx % (td / 2); T c = cos[rope_idx]; @@ -163,7 +163,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons template __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (2 * idx > bh * td) return; + if (2 * idx >= bh * td) return; uint32_t i_bh = idx / (td / 2); uint32_t i_td = idx - (td / 2) * i_bh; diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index 0e6200f5..d899c712 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -88,13 +88,6 @@ struct RotaryEmbedding { cos: Tensor, } -fn rotate_half(xs: &Tensor) -> Result { - let last_dim = xs.dim(D::Minus1)?; - let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; - let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; - Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) -} - impl RotaryEmbedding { fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { let rope_theta = cfg.rope_theta as f32; @@ -110,7 +103,6 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -126,10 +118,8 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; - let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?; Ok((q_embed, k_embed)) } } @@ -252,10 +242,12 @@ impl Attention { let query_states = query_states .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let key_states = key_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let value_states = value_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs index 2c5b7f74..e37785de 100644 --- a/candle-transformers/src/models/quantized_mistral.rs +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -12,13 +12,6 @@ struct RotaryEmbedding { cos: Tensor, } -fn rotate_half(xs: &Tensor) -> Result { - let last_dim = xs.dim(D::Minus1)?; - let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; - let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; - Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) -} - impl RotaryEmbedding { fn new(cfg: &Config, dev: &Device) -> Result { let rope_theta = cfg.rope_theta as f32; @@ -34,7 +27,6 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -50,10 +42,8 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; - let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?; Ok((q_embed, k_embed)) } } @@ -158,10 +148,12 @@ impl Attention { let query_states = query_states .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let key_states = key_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let value_states = value_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?;