Add argsort. (#2132)

* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated tests.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
Laurent Mazare
2024-04-27 20:17:35 +02:00
committed by GitHub
parent 6cf82fd7a3
commit 96a48e5cc4
11 changed files with 489 additions and 44 deletions
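
As a reference for what the new CUDA, CPU and Metal kernels compute, here is a minimal CPU sketch of the per-row argsort semantics in plain Rust. It is written for this note rather than taken from the commit, and the helper name is hypothetical:

// Hypothetical helper, not part of the candle API: for each row of a row-major
// (nrows x ncols) matrix, return the u32 indices that would sort that row.
fn argsort_rows(values: &[f32], nrows: usize, ncols: usize, ascending: bool) -> Vec<u32> {
    let mut out = Vec::with_capacity(nrows * ncols);
    for r in 0..nrows {
        let row = &values[r * ncols..(r + 1) * ncols];
        let mut idx: Vec<u32> = (0..ncols as u32).collect();
        idx.sort_by(|&a, &b| {
            let ord = row[a as usize].total_cmp(&row[b as usize]);
            if ascending { ord } else { ord.reverse() }
        });
        out.extend_from_slice(&idx);
    }
    out
}

fn main() {
    // Ascending argsort of the single row [3.0, 1.0, 2.0] is [1, 2, 0].
    assert_eq!(argsort_rows(&[3.0, 1.0, 2.0], 1, 3, true), vec![1, 2, 0]);
}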

@@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
const SORT: &str = include_str!("sort.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@@ -35,6 +36,7 @@ pub enum Source {
    Conv,
    Random,
    Quantized,
    Sort,
}
pub mod copy2d {
@@ -197,6 +199,7 @@ impl Kernels {
            Source::Conv => CONV,
            Source::Random => RANDOM,
            Source::Quantized => QUANTIZED,
            Source::Sort => SORT,
            Source::Mfa => panic!("Invalid lib"),
        }
    }
@@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d(
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_arg_sort(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    nrows: usize,
    ncols: usize,
    ncols_pad: usize,
    src: BufferOffset,
    dst: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));

    // One threadgroup per row, ncols_pad threads per threadgroup.
    let thread_group_count = MTLSize {
        width: 1,
        height: nrows as u64,
        depth: 1,
    };
    let thread_group_size = MTLSize {
        width: ncols_pad as u64,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(dst, metal::MTLResourceUsage::Write);
    // Threadgroup memory: one u32 index slot per padded column.
    encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();

    Ok(())
}
#[cfg(test)]
mod tests;
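
The dispatch above uses one threadgroup per row, ncols_pad threads per threadgroup, and one u32 slot of threadgroup memory per padded column, so the caller has to round the row length up to a power of two before calling call_arg_sort. A sketch of that preparation, assumed here rather than taken from this hunk:

// Assumed caller-side sizing, not code from this diff: the bitonic network in
// sort.metal needs a power-of-two row width, so pad virtually and size the
// threadgroup memory to one u32 per padded column (with the 16-byte floor above).
fn arg_sort_dispatch_sizes(nrows: usize, ncols: usize) -> (usize, usize, u64) {
    let ncols_pad = ncols.next_power_of_two();           // e.g. 1000 -> 1024
    let threadgroups = nrows;                            // one threadgroup per row
    let tg_mem_bytes = ((ncols_pad * 4).max(16)) as u64; // matches set_threadgroup_memory_length
    (ncols_pad, threadgroups, tg_mem_bytes)
}

Kernel names follow the ARGSORT macro in sort.metal below, e.g. asort_asc_f32 or asort_desc_f16. Note that ncols_pad is also the threadgroup width, so rows wider than the device's maximum threads per threadgroup cannot be sorted by this dispatch as written.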

@@ -1,3 +1,4 @@
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;

@@ -0,0 +1,97 @@
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define SORT_ASC 1
#define SORT_DESC 0
template<int order, typename T>
METAL_FUNC void argsort(
        device const T * x,
        device uint32_t * dst,
        constant int64_t & ncols,
        constant int64_t & ncols_pad,
        threadgroup uint32_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {
    // Bitonic argsort of one row per threadgroup; padded columns (index >= ncols) are pushed to the end.
    int col = tpitg[0];
    int row = tgpig[1];

    if (col >= ncols_pad) return;

    device const T * x_row = x + row * ncols;
    threadgroup uint32_t * dst_row = shared_values;

    // initialize indices
    dst_row[col] = col;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= ncols ||
                        (dst_row[ixj] < ncols && (order == SORT_ASC ?
                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                    ) {
                        SWAP(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (dst_row[ixj] >= ncols ||
                        (dst_row[col] < ncols && (order == SORT_ASC ?
                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                    ) {
                        SWAP(dst_row[col], dst_row[ixj]);
                    }
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // copy the result to dst without the padding
    if (col < ncols) {
        dst[row * ncols + col] = dst_row[col];
    }
}
#define ARGSORT(T, RUST_T) \
kernel void asort_asc_##RUST_T( \
    device const T * x, \
    device uint32_t * dst, \
    constant int64_t & ncols, \
    constant int64_t & ncols_pad, \
    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \
kernel void asort_desc_##RUST_T( \
    device const T * x, \
    device uint32_t * dst, \
    constant int64_t & ncols, \
    constant int64_t & ncols_pad, \
    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \

ARGSORT(float, f32)
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)
#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
#endif
#if defined(__HAVE_BFLOAT__)
ARGSORT(bfloat, bf16)
#endif
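
For reference, the same bitonic network can be checked on the CPU. The following Rust sketch is not part of this commit; it simulates every thread position in lock step for a single row, with padded indices (>= ncols) sorting to the end, and should agree with the ascending Metal kernel above:

// CPU re-implementation of the ascending bitonic argsort above, for checking
// the kernel logic (a sketch, not code from this commit).
fn bitonic_argsort_asc(x_row: &[f32]) -> Vec<u32> {
    let ncols = x_row.len();
    let ncols_pad = ncols.next_power_of_two();
    // dst_row plays the role of the threadgroup index buffer.
    let mut dst_row: Vec<u32> = (0..ncols_pad as u32).collect();
    let mut k = 2;
    while k <= ncols_pad {
        let mut j = k / 2;
        while j > 0 {
            // Each (col, col ^ j) pair is handled once, by its lower index.
            for col in 0..ncols_pad {
                let ixj = col ^ j;
                if ixj <= col {
                    continue;
                }
                let (a, b) = (dst_row[col] as usize, dst_row[ixj] as usize);
                let swap = if (col & k) == 0 {
                    // ascending subsequence: padded or larger entries move towards ixj
                    a >= ncols || (b < ncols && x_row[a] > x_row[b])
                } else {
                    // descending subsequence
                    b >= ncols || (a < ncols && x_row[a] < x_row[b])
                };
                if swap {
                    dst_row.swap(col, ixj);
                }
            }
            j /= 2;
        }
        k *= 2;
    }
    // Drop the padding, exactly like the final copy in the kernel.
    dst_row.truncate(ncols);
    dst_row
}

fn main() {
    // 0.5 < 1.0 < 2.0 < 3.0, so the ascending index order is [3, 1, 2, 0].
    assert_eq!(bitonic_argsort_asc(&[3.0, 1.0, 2.0, 0.5]), vec![3, 1, 2, 0]);
}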