mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add argsort. (#2132)
* Add the argsort cuda kernels.
* CPU version of arg-sort.
* Hook the cuda kernel + rework the cpu bits.
* Add some dedicated test.
* Working cuda kernel.
* Metal kernel.
* Metal adjustments.
* Bugfix.
* Use the fast rope in qwen.
* Rework the expert selection in qwen.
This commit is contained in:
@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const SORT: &str = include_str!("sort.metal");
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
@ -35,6 +36,7 @@ pub enum Source {
|
||||
Conv,
|
||||
Random,
|
||||
Quantized,
|
||||
Sort,
|
||||
}
|
||||
|
||||
pub mod copy2d {
|
||||
@ -197,6 +199,7 @@ impl Kernels {
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Sort => SORT,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_arg_sort(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
ncols_pad: usize,
|
||||
src: BufferOffset,
|
||||
dst: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: nrows as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let thread_group_size = MTLSize {
|
||||
width: ncols_pad as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
@ -1,3 +1,4 @@
|
||||
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
97
candle-metal-kernels/src/sort.metal
Normal file
97
candle-metal-kernels/src/sort.metal
Normal file
@ -0,0 +1,97 @@
|
||||
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define SORT_ASC 1
#define SORT_DESC 0

// Per-row arg-sort using a bitonic sorting network over threadgroup memory.
//
// Launch layout read by this function: the row is taken from the threadgroup's
// y position in the grid and the column from the thread's x position in the
// threadgroup, i.e. one threadgroup per row and one thread per padded column.
//
//   x             input values; row r starts at x + r * ncols
//   dst           output indices; row r starts at dst + r * ncols
//   ncols         number of real entries per row
//   ncols_pad     number of slots taking part in the sorting network
//   shared_values threadgroup scratch, one uint32_t slot per padded column
//
// Slots holding an index >= ncols are padding: they are never used to read
// from x and are moved behind every real entry, so only the first ncols slots
// are written back to dst.
//
// NOTE(review): threads with col >= ncols_pad return before the barriers
// below, so this relies on the threadgroup being launched with exactly
// ncols_pad threads; ncols_pad is presumably a power of two. Confirm at the
// host-side dispatch.
template<int order, typename T>
METAL_FUNC void argsort(
        device const T        * x,
        device       uint32_t * dst,
        constant     int64_t  & ncols,
        constant     int64_t  & ncols_pad,
        threadgroup  uint32_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {
    const int col = tpitg[0];
    const int row = tgpig[1];

    if (col >= ncols_pad) return;

    device const T * x_row = x + row * ncols;
    threadgroup uint32_t * idx = shared_values;

    // Every slot starts out holding its own column number.
    idx[col] = col;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int span = 2; span <= ncols_pad; span *= 2) {
        for (int stride = span / 2; stride > 0; stride /= 2) {
            const int partner = col ^ stride;

            // Only the lower-numbered thread of each pair does the work.
            if (partner > col) {
                bool out_of_order;
                if ((col & span) == 0) {
                    // This half of the network keeps padding in the upper slot.
                    out_of_order = idx[col] >= ncols ||
                        (idx[partner] < ncols && (order == SORT_ASC ?
                            x_row[idx[col]] > x_row[idx[partner]] :
                            x_row[idx[col]] < x_row[idx[partner]]));
                } else {
                    // Opposite direction: padding is kept in the lower slot.
                    out_of_order = idx[partner] >= ncols ||
                        (idx[col] < ncols && (order == SORT_ASC ?
                            x_row[idx[col]] < x_row[idx[partner]] :
                            x_row[idx[col]] > x_row[idx[partner]]));
                }

                if (out_of_order) {
                    const uint32_t held = idx[col];
                    idx[col] = idx[partner];
                    idx[partner] = held;
                }
            }

            // All exchanges of this step must be visible before the next one.
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // Write back the sorted indices, leaving out the padding slots.
    if (col < ncols) {
        dst[row * ncols + col] = idx[col];
    }
}
|
||||
|
||||
// Emits a single kernel entry point called NAME that forwards all of its
// arguments to argsort<ORDER, T>.
#define ARGSORT_KERNEL(NAME, ORDER, T) \
kernel void NAME( \
    device const T        * x, \
    device       uint32_t * dst, \
    constant     int64_t  & ncols, \
    constant     int64_t  & ncols_pad, \
    threadgroup  uint32_t * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<ORDER, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
}

// Generates the ascending and descending kernels for element type T, named
// asort_asc_<RUST_T> and asort_desc_<RUST_T>.
#define ARGSORT(T, RUST_T) \
ARGSORT_KERNEL(asort_asc_##RUST_T, SORT_ASC, T) \
ARGSORT_KERNEL(asort_desc_##RUST_T, SORT_DESC, T)

ARGSORT(float, f32)
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)

// The i64 kernels are only generated when __METAL_VERSION__ is at least 220.
#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
#endif
// The bf16 kernels are only generated when __HAVE_BFLOAT__ is defined.
#if defined(__HAVE_BFLOAT__)
ARGSORT(bfloat, bf16)
#endif
|
Reference in New Issue
Block a user