mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d. * Move the cat operations. * Avoid transpositions in cat. * Bugfix. * Bugfix for the cuda kernel. * Add a benchmark. * Add more testing. * Test fix. * Faster kernel. * Add the missing kernel. * Tweak the test. * Add a metal kernel. * Fix for the metal kernel. * Get the tests to pass on metal. * Also use this opportunity to fix the metal kernel for ELU. * Add some bf16 kernels. * Clippy fixes.
This commit is contained in:
@ -127,6 +127,16 @@ pub enum Source {
|
||||
Quantized,
|
||||
}
|
||||
|
||||
pub mod copy2d {
|
||||
pub struct Kernel(pub &'static str);
|
||||
pub const FLOAT: Kernel = Kernel("copy2d_f32");
|
||||
pub const HALF: Kernel = Kernel("copy2d_f16");
|
||||
pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
|
||||
pub const I64: Kernel = Kernel("copy2d_i64");
|
||||
pub const U32: Kernel = Kernel("copy2d_u32");
|
||||
pub const U8: Kernel = Kernel("copy2d_u8");
|
||||
}
|
||||
|
||||
macro_rules! ops{
|
||||
($($name:ident),+) => {
|
||||
|
||||
@ -365,6 +375,46 @@ pub fn call_unary_contiguous(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_copy2d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: copy2d::Kernel,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
d1: usize,
|
||||
d2: usize,
|
||||
src_s: usize,
|
||||
dst_s: usize,
|
||||
src_o_in_bytes: usize,
|
||||
dst_o_in_bytes: usize,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
d1,
|
||||
d2,
|
||||
src_s,
|
||||
dst_s,
|
||||
(input, src_o_in_bytes),
|
||||
(output, dst_o_in_bytes)
|
||||
)
|
||||
);
|
||||
|
||||
let width: usize = d1 * d2;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_strided(
|
||||
device: &Device,
|
||||
|
Reference in New Issue
Block a user