mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Fixing the kernels + launches to make them faster.
Cool work by @ivarflakstad Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
#![allow(clippy::too_many_arguments)]
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
|
||||
Device, Function, Library, MTLSize,
|
||||
ComputePipelineState, Device, Function, Library, MTLSize,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
@ -15,6 +15,24 @@ const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
|
||||
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
||||
let size = length as u64;
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
|
||||
let count = (size + width - 1) / width;
|
||||
let thread_group_count = MTLSize {
|
||||
width: count,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
(thread_group_count, thread_group_size)
|
||||
}
|
||||
|
||||
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
|
||||
<P as EncoderParam>::set_param(encoder, position, data)
|
||||
}
|
||||
@ -257,19 +275,7 @@ pub fn call_unary_contiguous(
|
||||
|
||||
set_params!(encoder, (length, input, output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
@ -314,17 +320,7 @@ pub fn call_unary_strided(
|
||||
);
|
||||
|
||||
let width: usize = shape.iter().product();
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -358,18 +354,7 @@ pub fn call_binary_contiguous(
|
||||
|
||||
set_params!(encoder, (length, left, right, output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -421,17 +406,7 @@ pub fn call_binary_strided(
|
||||
)
|
||||
);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -464,18 +439,7 @@ pub fn call_cast_contiguous(
|
||||
|
||||
set_params!(encoder, (length, input, output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -608,19 +572,7 @@ pub fn call_affine(
|
||||
|
||||
set_params!(encoder, (size, mul, add, input, output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
@ -672,18 +624,7 @@ pub fn call_where_cond_strided(
|
||||
)
|
||||
);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -730,19 +671,9 @@ pub fn call_index_select(
|
||||
)
|
||||
);
|
||||
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
|
||||
let grid_size = MTLSize {
|
||||
width: (dst_el as u64 + width - 1) / width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
encoder.dispatch_thread_groups(grid_size, thread_group_size);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user