Contiguous variant of the rope kernel. (#1929)

* Contiguous variant of the rope kernel.

* Add the cuda kernel.

* Metal kernel.
This commit is contained in:
Laurent Mazare
2024-03-25 09:11:20 +01:00
committed by GitHub
parent 1b98f84a2b
commit e7f8e72588
5 changed files with 389 additions and 13 deletions

View File

@ -849,6 +849,49 @@ pub fn call_rope_i(
Ok(())
}
/// Encodes one compute pass on `command_buffer` that runs the rope kernel `kernel_name`.
///
/// The pipeline is looked up in the `Source::Reduce` library. Kernel parameters are bound in
/// this order: `bh`, `td`, `d`, then `src`, `cos` and `sin` (each paired with its offset), and
/// finally `output`. The three inputs are declared as read resources and `output` as a write
/// resource before the dispatch is recorded.
///
/// The launch geometry comes from `linear_split` over `(bh * td) / 2` work items.
/// NOTE(review): presumably each thread processes a pair of values — confirm against the
/// kernel source.
///
/// # Errors
///
/// Returns a `MetalKernelError` when the pipeline for `kernel_name` cannot be loaded.
#[allow(clippy::too_many_arguments)]
pub fn call_rope(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    bh: usize,
    td: usize,
    d: usize,
    src: &Buffer,
    src_offset: usize,
    cos: &Buffer,
    cos_offset: usize,
    sin: &Buffer,
    sin_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pso = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;

    // Launch geometry only depends on the pipeline and the problem size.
    let work_items = (bh * td) / 2;
    let (groups, threads_per_group) = linear_split(&pso, work_items);

    let enc = command_buffer.new_compute_command_encoder();
    enc.set_compute_pipeline_state(&pso);
    set_params!(
        enc,
        (
            bh,
            td,
            d,
            (src, src_offset),
            (cos, cos_offset),
            (sin, sin_offset),
            output
        )
    );

    // Inputs are only read by the kernel; the output buffer is only written.
    for input in [src, cos, sin] {
        enc.use_resource(input, metal::MTLResourceUsage::Read);
    }
    enc.use_resource(output, metal::MTLResourceUsage::Write);

    enc.dispatch_thread_groups(groups, threads_per_group);
    enc.end_encoding();
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,