mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Contiguous variant of the rope kernel. (#1929)
* Contiguous variant of the rope kernel. * Add the cuda kernel. * Metal kernel.
This commit is contained in:
@ -849,6 +849,49 @@ pub fn call_rope_i(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_rope(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
bh: usize,
|
||||
td: usize,
|
||||
d: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
cos_offset: usize,
|
||||
sin: &Buffer,
|
||||
sin_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
bh,
|
||||
td,
|
||||
d,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
|
||||
encoder.use_resource(src, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(cos, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_affine(
|
||||
device: &Device,
|
||||
|
@ -313,8 +313,8 @@ kernel void NAME(
|
||||
} \
|
||||
} \
|
||||
|
||||
#define ROPEI(FN_NAME, TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
#define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \
|
||||
kernel void FN_NAME_I( \
|
||||
constant size_t &bh, \
|
||||
constant size_t &td, \
|
||||
device const TYPENAME *src, \
|
||||
@ -332,6 +332,31 @@ kernel void FN_NAME( \
|
||||
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
|
||||
dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
|
||||
}\
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &bh, \
|
||||
constant size_t &td, \
|
||||
constant size_t &d, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *cos, \
|
||||
device const TYPENAME *sin, \
|
||||
device TYPENAME *dst, \
|
||||
uint idx [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (2 * idx >= bh * td) { \
|
||||
return; \
|
||||
} \
|
||||
size_t i_bh = idx / (td / 2); \
|
||||
size_t i_td = idx - (td / 2) * i_bh; \
|
||||
size_t i_t = i_td / (d / 2); \
|
||||
size_t i_d = i_td - (d / 2) * i_t; \
|
||||
size_t i1 = i_bh * td + i_t * d + i_d; \
|
||||
size_t i2 = i1 + d / 2; \
|
||||
size_t i_cs = i_t * (d / 2) + i_d; \
|
||||
TYPENAME c = cos[i_cs]; \
|
||||
TYPENAME s = sin[i_cs]; \
|
||||
dst[i1] = src[i1] * c - src[i2] * s; \
|
||||
dst[i2] = src[i1] * s + src[i2] * c; \
|
||||
}\
|
||||
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
@ -361,8 +386,8 @@ SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
RMSNORM(rmsnorm_f32, float)
|
||||
RMSNORM(rmsnorm_f16, half)
|
||||
ROPEI(rope_i_f32, float)
|
||||
ROPEI(rope_i_f16, half)
|
||||
ROPEI(rope_f32, rope_i_f32, float)
|
||||
ROPEI(rope_f16, rope_i_f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
@ -381,5 +406,5 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
RMSNORM(rmsnorm_bf16, bfloat)
|
||||
ROPEI(rope_i_bf16, bfloat)
|
||||
ROPEI(rope_bf16, rope_i_bf16, bfloat)
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user