mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Add argsort. (#2132)
* Add the argsort cuda kernels. * CPU version of arg-sort. * Hook the cuda kernel + rework the cpu bits. * Add some dedicated test. * Working cuda kernel. * Metal kernel. * Metal adjustments. * Bugfix. * Use the fast rope in qwen. * Rework the expert selection in qwen.
This commit is contained in:
@ -6,5 +6,6 @@ pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||
|
88
candle-kernels/src/sort.cu
Normal file
88
candle-kernels/src/sort.cu
Normal file
@ -0,0 +1,88 @@
|
||||
// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu
|
||||
#define SORT_ORDER_ASC 1
|
||||
#define SORT_ORDER_DESC 0
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
template<typename T>
|
||||
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||
T tmp = a;
|
||||
a = b;
|
||||
b = tmp;
|
||||
}
|
||||
|
||||
template<int order, typename T>
|
||||
static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
|
||||
// bitonic sort
|
||||
int col = threadIdx.x;
|
||||
int row = blockIdx.y;
|
||||
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
}
|
||||
|
||||
const T * x_row = x + row * ncols;
|
||||
extern __shared__ int dst_row[];
|
||||
|
||||
// initialize indices
|
||||
dst_row[col] = col;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 2; k <= ncols_pad; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
}
|
||||
}
|
||||
|
||||
#define ASORT_OP(TYPENAME, RUST_NAME) \
|
||||
extern "C" __global__ void asort_asc_##RUST_NAME( \
|
||||
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
|
||||
) { \
|
||||
k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
|
||||
} \
|
||||
extern "C" __global__ void asort_desc_##RUST_NAME( \
|
||||
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
|
||||
) { \
|
||||
k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
ASORT_OP(__nv_bfloat16, bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
ASORT_OP(__half, f16)
|
||||
#endif
|
||||
|
||||
ASORT_OP(float, f32)
|
||||
ASORT_OP(double, f64)
|
||||
ASORT_OP(uint8_t, u8)
|
||||
ASORT_OP(uint32_t, u32)
|
||||
ASORT_OP(int64_t, i64)
|
Reference in New Issue
Block a user