mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Use the new rope kernel in mistral. (#1937)
* Use the new rope kernel in mistral.
* Compute the cos and sin with full precision.
* Bugfix.
This commit is contained in:
@ -150,7 +150,7 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
|
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
|
||||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (2 * idx > bh * td) return;
|
if (2 * idx >= bh * td) return;
|
||||||
|
|
||||||
uint32_t rope_idx = idx % (td / 2);
|
uint32_t rope_idx = idx % (td / 2);
|
||||||
T c = cos[rope_idx];
|
T c = cos[rope_idx];
|
||||||
@ -163,7 +163,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
|
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
|
||||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (2 * idx > bh * td) return;
|
if (2 * idx >= bh * td) return;
|
||||||
|
|
||||||
uint32_t i_bh = idx / (td / 2);
|
uint32_t i_bh = idx / (td / 2);
|
||||||
uint32_t i_td = idx - (td / 2) * i_bh;
|
uint32_t i_td = idx - (td / 2) * i_bh;
|
||||||
|
@ -88,13 +88,6 @@ struct RotaryEmbedding {
|
|||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let last_dim = xs.dim(D::Minus1)?;
|
|
||||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
|
||||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
|
||||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
let rope_theta = cfg.rope_theta as f32;
|
let rope_theta = cfg.rope_theta as f32;
|
||||||
@ -110,7 +103,6 @@ impl RotaryEmbedding {
|
|||||||
.to_dtype(dtype)?
|
.to_dtype(dtype)?
|
||||||
.reshape((max_seq_len, 1))?;
|
.reshape((max_seq_len, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sin: freqs.sin()?,
|
sin: freqs.sin()?,
|
||||||
cos: freqs.cos()?,
|
cos: freqs.cos()?,
|
||||||
@ -126,10 +118,8 @@ impl RotaryEmbedding {
|
|||||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
|
||||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
|
||||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
|
||||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
|
||||||
Ok((q_embed, k_embed))
|
Ok((q_embed, k_embed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -252,10 +242,12 @@ impl Attention {
|
|||||||
|
|
||||||
let query_states = query_states
|
let query_states = query_states
|
||||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
let key_states = key_states
|
let key_states = key_states
|
||||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
let value_states = value_states
|
let value_states = value_states
|
||||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?;
|
||||||
|
@ -12,13 +12,6 @@ struct RotaryEmbedding {
|
|||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let last_dim = xs.dim(D::Minus1)?;
|
|
||||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
|
||||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
|
||||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
|
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
let rope_theta = cfg.rope_theta as f32;
|
let rope_theta = cfg.rope_theta as f32;
|
||||||
@ -34,7 +27,6 @@ impl RotaryEmbedding {
|
|||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.reshape((max_seq_len, 1))?;
|
.reshape((max_seq_len, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sin: freqs.sin()?,
|
sin: freqs.sin()?,
|
||||||
cos: freqs.cos()?,
|
cos: freqs.cos()?,
|
||||||
@ -50,10 +42,8 @@ impl RotaryEmbedding {
|
|||||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
|
||||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
|
||||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
|
||||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
|
||||||
Ok((q_embed, k_embed))
|
Ok((q_embed, k_embed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -158,10 +148,12 @@ impl Attention {
|
|||||||
|
|
||||||
let query_states = query_states
|
let query_states = query_states
|
||||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
let key_states = key_states
|
let key_states = key_states
|
||||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
let value_states = value_states
|
let value_states = value_states
|
||||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?;
|
||||||
|
Reference in New Issue
Block a user