Add argsort. (#2132)

* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
This commit is contained in:
Laurent Mazare
2024-04-27 20:17:35 +02:00
committed by GitHub
parent 6cf82fd7a3
commit 96a48e5cc4
11 changed files with 489 additions and 44 deletions

View File

@ -63,6 +63,7 @@ pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod sort;
mod storage;
mod strided_index;
mod tensor;

View File

@ -11,7 +11,7 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
BufferOffset {
buffer,
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),

222
candle-core/src/sort.rs Normal file
View File

@ -0,0 +1,222 @@
use crate::{Result, Tensor};
use rayon::prelude::*;
#[derive(Debug, Clone, Copy)]
struct ArgSort {
asc: bool,
last_dim: usize,
}
impl ArgSort {
fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
#[allow(clippy::uninit_vec)]
// Safety: indexes are set later in the parallelized section.
let mut sort_indexes = unsafe {
let el_count = layout.shape().elem_count();
let mut v = Vec::with_capacity(el_count);
v.set_len(el_count);
v
};
if self.asc {
sort_indexes
.par_chunks_exact_mut(self.last_dim)
.zip(vs.par_chunks_exact(self.last_dim))
.for_each(|(indexes, vs)| {
indexes
.iter_mut()
.enumerate()
.for_each(|(i, v)| *v = i as u32);
indexes.sort_by(|&i, &j| {
vs[i as usize]
.partial_cmp(&vs[j as usize])
.unwrap_or(std::cmp::Ordering::Greater)
})
});
} else {
sort_indexes
.par_chunks_exact_mut(self.last_dim)
.zip(vs.par_chunks_exact(self.last_dim))
.for_each(|(indexes, vs)| {
indexes
.iter_mut()
.enumerate()
.for_each(|(i, v)| *v = i as u32);
indexes.sort_by(|&j, &i| {
vs[i as usize]
.partial_cmp(&vs[j as usize])
.unwrap_or(std::cmp::Ordering::Greater)
})
});
}
sort_indexes
}
}
impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
}
fn cpu_fwd(
&self,
storage: &crate::CpuStorage,
layout: &crate::Layout,
) -> Result<(crate::CpuStorage, crate::Shape)> {
let sort_indexes = match storage {
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
};
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
Ok((sort_indexes, layout.shape().into()))
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
use crate::{CudaDevice, WithDType};
impl Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
use crate::backend::BackendStorage;
let dev = storage.device();
let slice = self.map(&storage.slice, dev, layout)?;
let dst = crate::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, crate::Shape)> {
use crate::backend::BackendStorage;
use crate::DType;
let name = {
if self.asc {
match storage.dtype() {
DType::BF16 => "asort_asc_bf16",
DType::F16 => "asort_asc_f16",
DType::F32 => "asort_asc_f32",
DType::F64 => "asort_asc_f64",
DType::U8 => "asort_asc_u8",
DType::U32 => "asort_asc_u32",
DType::I64 => "asort_asc_i64",
}
} else {
match storage.dtype() {
DType::BF16 => "asort_desc_bf16",
DType::F16 => "asort_desc_f16",
DType::F32 => "asort_desc_f32",
DType::F64 => "asort_desc_f64",
DType::U8 => "asort_desc_u8",
DType::U32 => "asort_desc_u32",
DType::I64 => "asort_desc_i64",
}
}
};
let device = storage.device();
let kernels = device.kernels();
let command_buffer = device.command_buffer()?;
let el = layout.shape().elem_count();
let ncols = self.last_dim;
let nrows = el / ncols;
let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
let dst = device.new_buffer(el, DType::U32, "asort")?;
let mut ncols_pad = 1;
while ncols_pad < ncols {
ncols_pad *= 2;
}
candle_metal_kernels::call_arg_sort(
device.metal_device(),
&command_buffer,
kernels,
&name,
nrows,
ncols,
ncols_pad,
src,
&dst,
)
.map_err(crate::Error::wrap)?;
let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
Ok((dst, layout.shape().clone()))
}
}
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
let mut n = 1;
while n < x {
n *= 2
}
n
}
impl Tensor {
/// Returns the indices that sort the tensor along the last dimension.
///
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
/// descending order. The sort is unstable so there is no guarantees on the final order when it
/// comes to ties.
pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
if !self.is_contiguous() {
return Err(crate::Error::RequiresContiguous {
op: "arg_sort_last_dim",
});
}
let last_dim = match self.dims().last() {
None => crate::bail!("empty last-dim in arg-sort"),
Some(last_dim) => *last_dim,
};
// No need for a backward pass for arg sort.
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
}
}

View File

@ -96,6 +96,22 @@ fn clamp(device: &Device) -> Result<()> {
Ok(())
}
fn asort(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let indexes = tensor.arg_sort_last_dim(true)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
);
let indexes = tensor.arg_sort_last_dim(false)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
);
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
@ -1151,6 +1167,7 @@ test_device!(
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);

View File

@ -6,5 +6,6 @@ pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

View File

@ -0,0 +1,88 @@
// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu
#define SORT_ORDER_ASC 1
#define SORT_ORDER_DESC 0
#include "cuda_utils.cuh"
#include<stdint.h>
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}
template<int order, typename T>
static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
if (col >= ncols_pad) {
return;
}
const T * x_row = x + row * ncols;
extern __shared__ int dst_row[];
// initialize indices
dst_row[col] = col;
__syncthreads();
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
}
}
__syncthreads();
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
#define ASORT_OP(TYPENAME, RUST_NAME) \
extern "C" __global__ void asort_asc_##RUST_NAME( \
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
) { \
k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
} \
extern "C" __global__ void asort_desc_##RUST_NAME( \
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
) { \
k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
} \
#if __CUDA_ARCH__ >= 800
ASORT_OP(__nv_bfloat16, bf16)
#endif
#if __CUDA_ARCH__ >= 530
ASORT_OP(__half, f16)
#endif
ASORT_OP(float, f32)
ASORT_OP(double, f64)
ASORT_OP(uint8_t, u8)
ASORT_OP(uint32_t, u32)
ASORT_OP(int64_t, i64)

View File

@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
const SORT: &str = include_str!("sort.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@ -35,6 +36,7 @@ pub enum Source {
Conv,
Random,
Quantized,
Sort,
}
pub mod copy2d {
@ -197,6 +199,7 @@ impl Kernels {
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Sort => SORT,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_arg_sort(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
nrows: usize,
ncols: usize,
ncols_pad: usize,
src: BufferOffset,
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
let thread_group_count = MTLSize {
width: 1,
height: nrows as u64,
depth: 1,
};
let thread_group_size = MTLSize {
width: ncols_pad as u64,
height: 1,
depth: 1,
};
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests;

View File

@ -1,3 +1,4 @@
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;

View File

@ -0,0 +1,97 @@
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define SORT_ASC 1
#define SORT_DESC 0
template<int order, typename T>
METAL_FUNC void argsort(
device const T * x,
device uint32_t * dst,
constant int64_t & ncols,
constant int64_t & ncols_pad,
threadgroup uint32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
int col = tpitg[0];
int row = tgpig[1];
if (col >= ncols_pad) return;
device const T * x_row = x + row * ncols;
threadgroup uint32_t * dst_row = shared_values;
// initialize indices
dst_row[col] = col;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == SORT_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
SWAP(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == SORT_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
SWAP(dst_row[col], dst_row[ixj]);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
#define ARGSORT(T, RUST_T) \
kernel void asort_asc_##RUST_T( \
device const T * x, \
device uint32_t * dst, \
constant int64_t & ncols, \
constant int64_t & ncols_pad, \
threadgroup uint32_t * shared_values [[threadgroup(0)]], \
uint3 tgpig[[threadgroup_position_in_grid]], \
uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \
kernel void asort_desc_##RUST_T( \
device const T * x, \
device uint32_t * dst, \
constant int64_t & ncols, \
constant int64_t & ncols_pad, \
threadgroup uint32_t * shared_values [[threadgroup(0)]], \
uint3 tgpig[[threadgroup_position_in_grid]], \
uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \
ARGSORT(float, f32)
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)
#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
#endif
#if defined(__HAVE_BFLOAT__)
ARGSORT(bfloat, bf16)
#endif

View File

@ -27,13 +27,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
@ -48,7 +41,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@ -64,10 +56,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}

View File

@ -33,13 +33,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
@ -54,7 +47,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@ -70,10 +62,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
@ -259,30 +249,28 @@ impl Module for SparseMoeBlock {
// In order to extract topk, we extract the data from the tensor and manipulate it
// directly. Maybe we will want to use some custom ops instead at some point.
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let experts_per_tok = routing_weights
.arg_sort_last_dim(false)?
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
.contiguous()?;
let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
// routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
// top_x contains the row indexes to evaluate for each expert.
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
let mut top_x = vec![vec![]; self.experts.len()];
let mut selected_experts = vec![vec![]; self.experts.len()];
for (row_idx, rw) in routing_weights.iter().enumerate() {
let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
let mut sum_routing_weights = 0f32;
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
let expert_idx = expert_idx as usize;
let routing_weight = rw[expert_idx];
sum_routing_weights += routing_weight;
top_x[expert_idx].push(row_idx as u32);
}
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
let expert_idx = expert_idx as usize;
let routing_weight = if self.norm_topk_prob {
rw[expert_idx] / sum_routing_weights
} else {
rw[expert_idx]
};
selected_experts[expert_idx].push(routing_weight)
for (row_idx, (rw, expert_idxs)) in routing_weights
.iter()
.zip(experts_per_tok.iter())
.enumerate()
{
let sum_rw = rw.iter().sum::<f32>();
for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
top_x[expert_idx as usize].push(row_idx as u32);
let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
selected_experts[expert_idx as usize].push(rw)
}
}