Optimize copy-2d for metal. (#2024)

* Optimize copy-2d for metal.

* Add a hacky stopping rule for moondream.
This commit is contained in:
Laurent Mazare
2024-04-07 12:34:16 +02:00
committed by GitHub
parent 33c9b66554
commit 7f354473cf
3 changed files with 58 additions and 21 deletions

View File

@ -40,6 +40,44 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
(thread_group_count, thread_group_size)
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
let mut sum = 0u64;
loop {
let presum = sum;
// Check all the pows
if dim0 >= (1 << (pows0 + 1)) {
pows0 += 1;
sum += 1;
}
if sum == 10 {
break;
}
if dim1 >= (1 << (pows1 + 1)) {
pows1 += 1;
sum += 1;
}
if sum == 10 {
break;
}
if dim2 >= (1 << (pows2 + 1)) {
pows2 += 1;
sum += 1;
}
if sum == presum || sum == 10 {
break;
}
}
MTLSize {
width: 1 << pows0,
height: 1 << pows1,
depth: 1 << pows2,
}
}
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
@ -396,21 +434,24 @@ pub fn call_copy2d(
set_params!(
encoder,
(
d1,
d2,
src_s,
dst_s,
d1 as i64,
d2 as i64,
src_s as i64,
dst_s as i64,
(input, src_o_in_bytes),
(output, dst_o_in_bytes)
)
);
let width: usize = d1 * d2;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
let grid_dims = MTLSize {
width: d1 as u64,
height: d2 as u64,
depth: 1,
};
let group_dims = get_block_dims(d1 as u64, d2 as u64, 1);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.dispatch_threads(grid_dims, group_dims);
encoder.end_encoding();
Ok(())
}