Optimize the cat operation on contiguous tensors (#1855)

* Add a specialized kernel for copy2d.

* Move the cat operations.

* Avoid transpositions in cat.

* Bugfix.

* Bugfix for the cuda kernel.

* Add a benchmark.

* Add more testing.

* Test fix.

* Faster kernel.

* Add the missing kernel.

* Tweak the test.

* Add a metal kernel.

* Fix for the metal kernel.

* Get the tests to pass on metal.

* Also use this opportunity to fix the metal kernel for ELU.

* Add some bf16 kernels.

* Clippy fixes.
This commit is contained in:
Laurent Mazare
2024-03-17 10:49:13 +01:00
committed by GitHub
parent db8b24ae92
commit ce9fbc3682
19 changed files with 744 additions and 208 deletions

View File

@ -127,6 +127,16 @@ pub enum Source {
Quantized,
}
/// Kernel names for the specialized 2D copy operation.
///
/// Each constant wraps the name string that `call_copy2d` passes to the
/// pipeline loader; there is one constant per supported element type.
pub mod copy2d {
    /// Name of a `copy2d` kernel entry point.
    ///
    /// The wrapped string is public so callers can read it with `.0`. The
    /// derives make the handle printable, comparable and freely copyable,
    /// which is cheap since it only holds a `&'static str`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct Kernel(pub &'static str);

    /// Kernel named `copy2d_f32`.
    pub const FLOAT: Kernel = Kernel("copy2d_f32");
    /// Kernel named `copy2d_f16`.
    pub const HALF: Kernel = Kernel("copy2d_f16");
    /// Kernel named `copy2d_bf16`.
    pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
    /// Kernel named `copy2d_i64`.
    pub const I64: Kernel = Kernel("copy2d_i64");
    /// Kernel named `copy2d_u32`.
    pub const U32: Kernel = Kernel("copy2d_u32");
    /// Kernel named `copy2d_u8`.
    pub const U8: Kernel = Kernel("copy2d_u8");
}
macro_rules! ops{
($($name:ident),+) => {
@ -365,6 +375,46 @@ pub fn call_unary_contiguous(
Ok(())
}
/// Encodes one dispatch of a `copy2d` kernel onto `command_buffer`.
///
/// The pipeline is looked up by the name carried in `name`, within the
/// `Source::Unary` kernel source. `d1`, `d2`, `src_s` and `dst_s` are forwarded
/// to the kernel unchanged, followed by the `input` and `output` buffers bound
/// at `src_o_in_bytes` and `dst_o_in_bytes` respectively.
/// NOTE(review): presumably `d1`/`d2` are the extents of the copied 2D region
/// and `src_s`/`dst_s` the source/destination strides — confirm against the
/// Metal kernel.
///
/// # Errors
///
/// Propagates any `MetalKernelError` returned by `Kernels::load_pipeline`.
// The argument list mirrors the kernel's parameter list, hence the lint opt-out.
#[allow(clippy::too_many_arguments)]
pub fn call_copy2d(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: copy2d::Kernel,
    input: &Buffer,
    output: &Buffer,
    d1: usize,
    d2: usize,
    src_s: usize,
    dst_s: usize,
    src_o_in_bytes: usize,
    dst_o_in_bytes: usize,
) -> Result<(), MetalKernelError> {
    let copy2d::Kernel(kernel_name) = name;
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    // Argument order must match the kernel signature: the four scalars first,
    // then the source and destination buffers with their byte offsets.
    set_params!(
        encoder,
        (
            d1,
            d2,
            src_s,
            dst_s,
            (input, src_o_in_bytes),
            (output, dst_o_in_bytes)
        )
    );

    // The launch is sized from the total element count d1 * d2.
    let element_count: usize = d1 * d2;
    let (group_count, group_size) = linear_split(&pipeline, element_count);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(group_count, group_size);
    encoder.end_encoding();

    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,