Add the rope THD kernel. (#2014)

* Add the rope THD kernel.

* Cuda kernel for rope-thd.

* Add the metal kernels.

* Add a dedicated test.
This commit is contained in:
Laurent Mazare
2024-04-05 08:32:58 +02:00
committed by GitHub
parent ace282e5c2
commit 2ac302a5d1
6 changed files with 400 additions and 31 deletions

View File

@ -849,6 +849,51 @@ pub fn call_rope_i(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope_thd(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
b: usize,
t: usize,
h: usize,
d: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
cos_offset: usize,
sin: &Buffer,
sin_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
b,
t,
h,
d,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2);
encoder.use_resource(src, metal::MTLResourceUsage::Read);
encoder.use_resource(cos, metal::MTLResourceUsage::Read);
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope(
device: &Device,