mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add the rope THD kernel. (#2014)
* Add the rope THD kernel. * Cuda kernel for rope-thd. * Add the metal kernels. * Add a dedicated test.
This commit is contained in:
@ -849,6 +849,51 @@ pub fn call_rope_i(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_rope_thd(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
b: usize,
|
||||
t: usize,
|
||||
h: usize,
|
||||
d: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
cos_offset: usize,
|
||||
sin: &Buffer,
|
||||
sin_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
b,
|
||||
t,
|
||||
h,
|
||||
d,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2);
|
||||
encoder.use_resource(src, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(cos, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_rope(
|
||||
device: &Device,
|
||||
|
Reference in New Issue
Block a user