mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add argsort. (#2132)
* Add the argsort cuda kernels.
* CPU version of arg-sort.
* Hook the cuda kernel + rework the cpu bits.
* Add some dedicated test.
* Working cuda kernel.
* Metal kernel.
* Metal adjustments.
* Bugfix.
* Use the fast rope in qwen.
* Rework the expert selection in qwen.
This commit is contained in:
@ -27,13 +27,6 @@ struct RotaryEmbedding {
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
/// Splits the last dimension of `xs` into a first part `xs1` and a second part `xs2`,
/// then returns the concatenation of `-xs2` followed by `xs1` along that same dimension.
///
/// Any error from `dim`, `narrow`, `neg` or `cat` is propagated to the caller.
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||
let last_dim = xs.dim(D::Minus1)?;
|
||||
// First part: indices [0, last_dim / 2) of the last dimension.
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||
// Second part: the remaining `last_dim - last_dim / 2` entries (one more than `xs1` when `last_dim` is odd).
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||
// The negated second part goes first, followed by the unchanged first part.
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
@ -48,7 +41,6 @@ impl RotaryEmbedding {
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
@ -64,10 +56,8 @@ impl RotaryEmbedding {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
@ -33,13 +33,6 @@ struct RotaryEmbedding {
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
/// Splits the last dimension of `xs` into a first part `xs1` and a second part `xs2`,
/// then returns the concatenation of `-xs2` followed by `xs1` along that same dimension.
///
/// Any error from `dim`, `narrow`, `neg` or `cat` is propagated to the caller.
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||
let last_dim = xs.dim(D::Minus1)?;
|
||||
// First part: indices [0, last_dim / 2) of the last dimension.
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||
// Second part: the remaining `last_dim - last_dim / 2` entries (one more than `xs1` when `last_dim` is odd).
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||
// The negated second part goes first, followed by the unchanged first part.
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
@ -54,7 +47,6 @@ impl RotaryEmbedding {
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
@ -70,10 +62,8 @@ impl RotaryEmbedding {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
@ -259,30 +249,28 @@ impl Module for SparseMoeBlock {
|
||||
|
||||
// In order to extract topk, we extract the data from the tensor and manipulate it
|
||||
// directly. Maybe we will want to use some custom ops instead at some point.
|
||||
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
|
||||
let experts_per_tok = routing_weights
|
||||
.arg_sort_last_dim(false)?
|
||||
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
|
||||
.contiguous()?;
|
||||
let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
|
||||
|
||||
// routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
// top_x contains the row indexes to evaluate for each expert.
|
||||
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
|
||||
let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
|
||||
let mut top_x = vec![vec![]; self.experts.len()];
|
||||
let mut selected_experts = vec![vec![]; self.experts.len()];
|
||||
for (row_idx, rw) in routing_weights.iter().enumerate() {
|
||||
let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
|
||||
dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
|
||||
let mut sum_routing_weights = 0f32;
|
||||
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
|
||||
let expert_idx = expert_idx as usize;
|
||||
let routing_weight = rw[expert_idx];
|
||||
sum_routing_weights += routing_weight;
|
||||
top_x[expert_idx].push(row_idx as u32);
|
||||
}
|
||||
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
|
||||
let expert_idx = expert_idx as usize;
|
||||
let routing_weight = if self.norm_topk_prob {
|
||||
rw[expert_idx] / sum_routing_weights
|
||||
} else {
|
||||
rw[expert_idx]
|
||||
};
|
||||
selected_experts[expert_idx].push(routing_weight)
|
||||
for (row_idx, (rw, expert_idxs)) in routing_weights
|
||||
.iter()
|
||||
.zip(experts_per_tok.iter())
|
||||
.enumerate()
|
||||
{
|
||||
let sum_rw = rw.iter().sum::<f32>();
|
||||
for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
|
||||
top_x[expert_idx as usize].push(row_idx as u32);
|
||||
let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
|
||||
selected_experts[expert_idx as usize].push(rw)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user