Add argsort. (#2132)

* Add the argsort cuda kernels.
* CPU version of arg-sort.
* Hook the cuda kernel + rework the cpu bits.
* Add some dedicated tests.
* Working cuda kernel.
* Metal kernel.
* Metal adjustments.
* Bugfix.
* Use the fast rope in qwen.
* Rework the expert selection in qwen.
candle-core/src/lib.rs

@@ -63,6 +63,7 @@ pub mod quantized;
 pub mod safetensors;
 pub mod scalar;
 pub mod shape;
+mod sort;
 mod storage;
 mod strided_index;
 mod tensor;
candle-core/src/metal_backend/mod.rs

@@ -11,7 +11,7 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
 mod device;
 pub use device::{DeviceId, MetalDevice};
 
-fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
+pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
     BufferOffset {
         buffer,
         offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
candle-core/src/sort.rs (new file, 222 lines)

@@ -0,0 +1,222 @@
use crate::{Result, Tensor};
use rayon::prelude::*;

#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        // Safety: indexes are set later in the parallelized section.
        let mut sort_indexes = unsafe {
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}
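
The CPU path sorts each row independently: rayon splits the index buffer and the values into `last_dim`-sized chunks, fills each index chunk with 0..last_dim, then sorts the indices by the values they point at. A condensed single-row sketch of the same idea (hypothetical helper, plain Rust, f32 only):

    // Hypothetical single-row version of the per-chunk logic above.
    fn argsort_row(vs: &[f32], asc: bool) -> Vec<u32> {
        let mut idx: Vec<u32> = (0..vs.len() as u32).collect();
        idx.sort_by(|&i, &j| {
            let ord = vs[i as usize]
                .partial_cmp(&vs[j as usize])
                .unwrap_or(std::cmp::Ordering::Greater);
            if asc { ord } else { ord.reverse() }
        });
        idx
    }

Note the `unwrap_or(Ordering::Greater)` in the real code: it keeps NaN comparisons from panicking at the cost of an arbitrary placement for NaNs.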

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
        };
        use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
        use crate::{CudaDevice, WithDType};

        impl Map1Any for ArgSort {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &crate::Layout,
                _wrap: W,
            ) -> Result<S> {
                let slice = match layout.contiguous_offsets() {
                    None => crate::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let elem_count = layout.shape().elem_count();
                let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let func = if self.asc {
                    dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
                } else {
                    dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
                };
                let ncols = self.last_dim;
                let nrows = elem_count / ncols;
                let ncols_pad = next_power_of_2(ncols);
                let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
                let cfg = LaunchConfig {
                    grid_dim: (1, nrows as u32, 1),
                    block_dim: (ncols_pad as u32, 1, 1),
                    shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
                };
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(S::U32(dst))
            }
        }

        use crate::backend::BackendStorage;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }
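
The launch maps one CUDA block to each row and one thread to each padded column, with the shared-memory scratch holding one u32 index per padded slot. A sketch of the arithmetic (hypothetical helper; note that it implies rows wider than the maximum block size, typically 1024 threads, cannot be sorted by this kernel):

    // Hypothetical helper mirroring the LaunchConfig computed above.
    fn cuda_launch_shape(elem_count: usize, ncols: usize) -> ((u32, u32, u32), (u32, u32, u32), u32) {
        let nrows = elem_count / ncols;
        let ncols_pad = ncols.next_power_of_two();
        let grid_dim = (1, nrows as u32, 1); // one block per row
        let block_dim = (ncols_pad as u32, 1, 1); // one thread per padded column
        let shared_mem_bytes = (ncols_pad * std::mem::size_of::<u32>()) as u32; // u32 index scratch
        (grid_dim, block_dim, shared_mem_bytes)
    }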

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            &name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}
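
This helper appears to behave identically to `usize::next_power_of_two` from the standard library (both map 0 and 1 to 1), so the latter could be used instead. A quick equivalence check, assuming the helper above is in scope:

    fn main() {
        for x in [0usize, 1, 5, 8, 1000] {
            // next_power_of_2 is the helper defined just above.
            assert_eq!(next_power_of_2(x), x.next_power_of_two());
        }
    }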

impl Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there are no guarantees on the final order
    /// of ties.
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }
}
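
A minimal usage sketch of the new API (assuming candle-core as a dependency):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::new(&[[3f32, 1., 4., 1.5, 9.]], &Device::Cpu)?;
        let idx = t.arg_sort_last_dim(true)?; // ascending
        assert_eq!(idx.to_vec2::<u32>()?, [[1, 3, 0, 2, 4]]);
        Ok(())
    }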

candle-core/tests/tensor_tests.rs

@@ -96,6 +96,22 @@ fn clamp(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn asort(device: &Device) -> Result<()> {
+    let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
+    let tensor = Tensor::new(data, device)?;
+    let indexes = tensor.arg_sort_last_dim(true)?;
+    assert_eq!(
+        indexes.to_vec2::<u32>()?,
+        [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
+    );
+    let indexes = tensor.arg_sort_last_dim(false)?;
+    assert_eq!(
+        indexes.to_vec2::<u32>()?,
+        [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
+    );
+    Ok(())
+}
+
 fn unary_op(device: &Device) -> Result<()> {
     let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
     let tensor = Tensor::new(data, device)?;

@@ -1151,6 +1167,7 @@ test_device!(
 );
 test_device!(randn, randn_cpu, randn_gpu, randn_metal);
 test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
+test_device!(asort, asort_cpu, asort_gpu, asort_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
 test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);

candle-kernels/src/lib.rs

@@ -6,5 +6,6 @@ pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
 pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
 pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
 pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
+pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
 pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
 pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
candle-kernels/src/sort.cu (new file, 88 lines)

@@ -0,0 +1,88 @@
// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu
#define SORT_ORDER_ASC 1
#define SORT_ORDER_DESC 0
#include "cuda_utils.cuh"
#include <stdint.h>

template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T tmp = a;
    a = b;
    b = tmp;
}

template<int order, typename T>
static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
    // bitonic sort
    int col = threadIdx.x;
    int row = blockIdx.y;

    if (col >= ncols_pad) {
        return;
    }

    const T * x_row = x + row * ncols;
    extern __shared__ int dst_row[];

    // initialize indices
    dst_row[col] = col;

    __syncthreads();

    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= ncols ||
                        (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                    ) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (dst_row[ixj] >= ncols ||
                        (dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                    ) {
                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                    }
                }
            }
            __syncthreads();
        }
    }

    // copy the result to dst without the padding
    if (col < ncols) {
        dst[row * ncols + col] = dst_row[col];
    }
}
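
The kernel is a textbook bitonic sorting network over indices: ncols_pad is a power of two, each stage compares lanes col and col ^ j, and padding lanes (index >= ncols) compare as larger than any real element so they sink past the valid region. A CPU reference of the ascending variant in Rust, useful for sanity-checking the comparator logic (a sketch, not the shipped code):

    // CPU reference of the ascending bitonic index sort (f32 only).
    fn bitonic_argsort_asc(x: &[f32]) -> Vec<u32> {
        let n = x.len();
        let n_pad = n.next_power_of_two();
        let mut idx: Vec<u32> = (0..n_pad as u32).collect();
        let mut k = 2;
        while k <= n_pad {
            let mut j = k / 2;
            while j > 0 {
                for col in 0..n_pad {
                    let ixj = col ^ j;
                    if ixj > col {
                        let (a, b) = (idx[col] as usize, idx[ixj] as usize);
                        // Padding indices (>= n) always swap towards the end.
                        let swap = if col & k == 0 {
                            a >= n || (b < n && x[a] > x[b])
                        } else {
                            b >= n || (a < n && x[a] < x[b])
                        };
                        if swap {
                            idx.swap(col, ixj);
                        }
                    }
                }
                j /= 2;
            }
            k *= 2;
        }
        idx.truncate(n); // drop the padding lanes
        idx
    }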

#define ASORT_OP(TYPENAME, RUST_NAME) \
extern "C" __global__ void asort_asc_##RUST_NAME( \
    const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
) { \
    k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
} \
extern "C" __global__ void asort_desc_##RUST_NAME( \
    const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
) { \
    k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
} \

#if __CUDA_ARCH__ >= 800
ASORT_OP(__nv_bfloat16, bf16)
#endif

#if __CUDA_ARCH__ >= 530
ASORT_OP(__half, f16)
#endif

ASORT_OP(float, f32)
ASORT_OP(double, f64)
ASORT_OP(uint8_t, u8)
ASORT_OP(uint32_t, u32)
ASORT_OP(int64_t, i64)

candle-metal-kernels/src/lib.rs

@@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal");
 const RANDOM: &str = include_str!("random.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
 const QUANTIZED: &str = include_str!("quantized.metal");
+const SORT: &str = include_str!("sort.metal");
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Source {

@@ -35,6 +36,7 @@ pub enum Source {
     Conv,
     Random,
     Quantized,
+    Sort,
 }
 
 pub mod copy2d {

@@ -197,6 +199,7 @@ impl Kernels {
             Source::Conv => CONV,
             Source::Random => RANDOM,
             Source::Quantized => QUANTIZED,
+            Source::Sort => SORT,
             Source::Mfa => panic!("Invalid lib"),
         }
     }

@@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d(
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_arg_sort(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    nrows: usize,
+    ncols: usize,
+    ncols_pad: usize,
+    src: BufferOffset,
+    dst: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: nrows as u64,
+        depth: 1,
+    };
+    let thread_group_size = MTLSize {
+        width: ncols_pad as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(dst, metal::MTLResourceUsage::Write);
+    encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests;
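
The dispatch mirrors the CUDA launch: one threadgroup per row, one thread per padded column, and a threadgroup buffer holding one u32 per padded slot. The `.max(16)` presumably keeps the allocation at Metal's minimum threadgroup-memory granularity (an assumption worth checking against the Metal docs). A sketch of the sizing (hypothetical helper):

    // Hypothetical helper mirroring the threadgroup memory sizing above.
    fn tg_memory_len(ncols: usize) -> u64 {
        let ncols_pad = ncols.next_power_of_two();
        ((ncols_pad * std::mem::size_of::<u32>()).max(16)) as u64
    }

    fn main() {
        assert_eq!(tg_memory_len(3), 16); // 4 padded slots * 4 bytes, floor of 16
        assert_eq!(tg_memory_len(100), 512); // 128 padded slots * 4 bytes
    }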

@@ -1,3 +1,4 @@
+// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
 #include <metal_stdlib>
 
 using namespace metal;

candle-metal-kernels/src/sort.metal (new file, 97 lines)

@@ -0,0 +1,97 @@
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;

#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define SORT_ASC 1
#define SORT_DESC 0

template<int order, typename T>
METAL_FUNC void argsort(
        device const T * x,
        device uint32_t * dst,
        constant int64_t & ncols,
        constant int64_t & ncols_pad,
        threadgroup uint32_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {
    int col = tpitg[0];
    int row = tgpig[1];

    if (col >= ncols_pad) return;

    device const T * x_row = x + row * ncols;
    threadgroup uint32_t * dst_row = shared_values;

    // initialize indices
    dst_row[col] = col;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= ncols ||
                        (dst_row[ixj] < ncols && (order == SORT_ASC ?
                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                    ) {
                        SWAP(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (dst_row[ixj] >= ncols ||
                        (dst_row[col] < ncols && (order == SORT_ASC ?
                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                    ) {
                        SWAP(dst_row[col], dst_row[ixj]);
                    }
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // copy the result to dst without the padding
    if (col < ncols) {
        dst[row * ncols + col] = dst_row[col];
    }
}

#define ARGSORT(T, RUST_T) \
kernel void asort_asc_##RUST_T( \
    device const T * x, \
    device uint32_t * dst, \
    constant int64_t & ncols, \
    constant int64_t & ncols_pad, \
    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \
kernel void asort_desc_##RUST_T( \
    device const T * x, \
    device uint32_t * dst, \
    constant int64_t & ncols, \
    constant int64_t & ncols_pad, \
    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \

ARGSORT(float, f32)
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)

#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
#endif
#if defined(__HAVE_BFLOAT__)
ARGSORT(bfloat, bf16)
#endif

candle-transformers/src/models/qwen2.rs

@@ -27,13 +27,6 @@ struct RotaryEmbedding {
     cos: Tensor,
 }
 
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
-    let last_dim = xs.dim(D::Minus1)?;
-    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
-    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
-    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.hidden_size / cfg.num_attention_heads;

@@ -48,7 +41,6 @@ impl RotaryEmbedding {
             .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             sin: freqs.sin()?,
             cos: freqs.cos()?,

@@ -64,10 +56,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
-        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
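
The removed broadcast formulation and the fused kernel compute the same rotation: with cos/sin duplicated across both halves, `q * cos + rotate_half(q) * sin` rotates each pair `(q[i], q[i + d/2])`, which is why the `Tensor::cat` duplication and the unsqueezes are dropped when switching to the fused op with its half-width tables. A scalar sketch of the pair rotation being fused (hypothetical function, not the candle kernel):

    // Rotates the pair (x1, x2) = (q[i], q[i + d/2]); matches
    // q * cos + rotate_half(q) * sin with rotate_half: (x1, x2) -> (-x2, x1).
    fn rope_pair(x1: f32, x2: f32, cos: f32, sin: f32) -> (f32, f32) {
        (x1 * cos - x2 * sin, x2 * cos + x1 * sin)
    }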

candle-transformers/src/models/qwen2_moe.rs (same rope change as above)

@@ -33,13 +33,6 @@ struct RotaryEmbedding {
     cos: Tensor,
 }
 
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
-    let last_dim = xs.dim(D::Minus1)?;
-    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
-    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
-    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.hidden_size / cfg.num_attention_heads;

@@ -54,7 +47,6 @@ impl RotaryEmbedding {
             .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             sin: freqs.sin()?,
             cos: freqs.cos()?,

@@ -70,10 +62,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
-        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }

@@ -259,30 +249,28 @@ impl Module for SparseMoeBlock {
 
         // In order to extract topk, we extract the data from the tensor and manipulate it
         // directly. Maybe we will want to use some custom ops instead at some point.
-        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+        let experts_per_tok = routing_weights
+            .arg_sort_last_dim(false)?
+            .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+            .contiguous()?;
+        let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
 
         // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         // top_x contains the row indexes to evaluate for each expert.
+        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+        let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
         let mut top_x = vec![vec![]; self.experts.len()];
         let mut selected_experts = vec![vec![]; self.experts.len()];
-        for (row_idx, rw) in routing_weights.iter().enumerate() {
-            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
-            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
-            let mut sum_routing_weights = 0f32;
-            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
-                let expert_idx = expert_idx as usize;
-                let routing_weight = rw[expert_idx];
-                sum_routing_weights += routing_weight;
-                top_x[expert_idx].push(row_idx as u32);
-            }
-            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
-                let expert_idx = expert_idx as usize;
-                let routing_weight = if self.norm_topk_prob {
-                    rw[expert_idx] / sum_routing_weights
-                } else {
-                    rw[expert_idx]
-                };
-                selected_experts[expert_idx].push(routing_weight)
-            }
-        }
+        for (row_idx, (rw, expert_idxs)) in routing_weights
+            .iter()
+            .zip(experts_per_tok.iter())
+            .enumerate()
+        {
+            let sum_rw = rw.iter().sum::<f32>();
+            for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
+                top_x[expert_idx as usize].push(row_idx as u32);
+                let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
+                selected_experts[expert_idx as usize].push(rw)
+            }
+        }
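
The rework moves the top-k selection onto the device: argsort descending, keep the first num_experts_per_tok indices, gather the matching weights, and only then copy the much smaller result back to the host. The pattern as a standalone sketch (hypothetical helper built on the new op):

    use candle_core::{D, Result, Tensor};

    /// Hypothetical helper: values and indices of the k largest entries per row.
    fn topk_last_dim(t: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
        let idx = t
            .arg_sort_last_dim(false)? // descending
            .narrow(D::Minus1, 0, k)?
            .contiguous()?;
        let vals = t.gather(&idx, D::Minus1)?;
        Ok((vals, idx))
    }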