Rework the buffer offset logic for metal kernels (#2028)

* Move the metal kernel utils into a separate module.

* Use the BufferOffset for unary ops.

* Fix clippy lints.

* Use the new BufferOffset.

* Adapt the binary ops.

* Affine.

* More ops (powf, elu, cast).
Laurent Mazare, 2024-04-07 22:37:53 +02:00 (committed by GitHub)
parent 7f354473cf, commit c5fe4a7f89
4 changed files with 305 additions and 286 deletions

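The BufferOffset type at the heart of this rework is defined in the new utils module and therefore does not appear in the hunks below. Judging from how it is used in this file (passed as &input to set_params! and dereferenced as input.buffer for use_resource), it plausibly pairs a buffer reference with a byte offset, roughly as in this sketch; the field name offset_in_bytes and the exact impl are assumptions, not part of the commit:

use metal::{Buffer, ComputeCommandEncoderRef};

/// A buffer together with the byte offset at which a kernel should start reading or writing.
pub struct BufferOffset<'a> {
    pub buffer: &'a Buffer,
    pub offset_in_bytes: usize, // assumed field name
}

/// The argument-setting trait that this commit moves into the utils module (see the hunk below).
trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}

// Plausible impl that lets `&BufferOffset` be handed straight to `set_params!`:
// the buffer is bound at its stored offset instead of always at offset 0.
impl EncoderParam for &BufferOffset<'_> {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
    }
}

With a single value carrying both buffer and offset, the (buffer, offset) tuple arguments and the separate *_offset parameters of the old signatures can be folded away, which is exactly what the hunks below do for the unary, binary, cast, affine, powf and elu entry points.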

@@ -1,11 +1,15 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split};
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
@@ -18,138 +22,6 @@ const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
/// Most kernels apply the same operation to every element of a tensor.
/// This builds a launch configuration that uses the maximum number of threads per threadgroup
/// (capped at the actual total buffer length), so each kernel invocation only has to handle its
/// single point in the buffer. A worked example follows the function.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
let count = (size + width - 1) / width;
let thread_group_count = MTLSize {
width: count,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
(thread_group_count, thread_group_size)
}
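// Worked example (illustrative): with length = 4000 and a 1024-thread-per-threadgroup limit,
// width = 1024 and count = (4000 + 1023) / 1024 = 4, i.e. 4 threadgroups of 1024 threads.
// The last 96 thread ids fall past the end of the buffer, so the kernels themselves are
// expected to bounds-check against the length before touching memory.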
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
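// Spreads up to 2^10 = 1024 threads over three dimensions, giving each dimension a
// power-of-two extent no larger than the corresponding dim.
// Illustrative example: get_block_dims(1024, 4, 1) yields MTLSize { width: 256, height: 4, depth: 1 }.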
fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
let mut sum = 0u64;
loop {
let presum = sum;
// Check all the pows
if dim0 >= (1 << (pows0 + 1)) {
pows0 += 1;
sum += 1;
}
if sum == 10 {
break;
}
if dim1 >= (1 << (pows1 + 1)) {
pows1 += 1;
sum += 1;
}
if sum == 10 {
break;
}
if dim2 >= (1 << (pows2 + 1)) {
pows2 += 1;
sum += 1;
}
if sum == presum || sum == 10 {
break;
}
}
MTLSize {
width: 1 << pows0,
height: 1 << pows1,
depth: 1 << pows2,
}
}
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper trait for setting the various kernel arguments on the compute command encoder
/// in a single line via the `set_params!` macro.
/// It prevents passing the wrong number of arguments and mixing up a length with a size in bytes.
/// An expansion example follows the macro definition below.
trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
($type:ty) => {
impl EncoderParam for $type {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<$type>() as u64,
&data as *const $type as *const c_void,
);
}
}
};
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of_val(data) as u64,
data.as_ptr() as *const c_void,
);
}
}
impl EncoderParam for &Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
impl EncoderParam for &mut Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&mut Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
macro_rules! set_params {
($encoder:ident, ($($param:expr),+)) => (
let mut _index = 0;
$(
set_param($encoder, _index, $param);
_index += 1;
)*
);
}
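// For instance, `set_params!(encoder, (length, &input, output))` expands to three consecutive
// `set_param` calls with argument indices 0, 1 and 2, one per parameter in the tuple.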
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
@@ -273,6 +145,12 @@ pub struct Kernels {
pipelines: RwLock<Pipelines>,
}
impl Default for Kernels {
fn default() -> Self {
Self::new()
}
}
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
@@ -396,17 +274,17 @@ pub fn call_unary_contiguous(
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
set_params!(encoder, (length, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
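As an illustrative aside (not part of this commit), a caller of the reworked call_unary_contiguous now builds a BufferOffset instead of passing a bare &Buffer plus offset. The sketch assumes the offset_in_bytes field name from the earlier sketch and that the crate exposes kernel constants such as unary::contiguous::cos::FLOAT; the strided, binary, cast, affine, powf and elu entry points adapted further down follow the same pattern:

use candle_metal_kernels::{call_unary_contiguous, unary, BufferOffset, Kernels};
use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().expect("no Metal device found");
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let data: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0];
    let bytes = (data.len() * std::mem::size_of::<f32>()) as u64;
    let input_buffer = device.new_buffer_with_data(
        data.as_ptr() as *const std::ffi::c_void,
        bytes,
        MTLResourceOptions::StorageModeShared,
    );
    let output = device.new_buffer(bytes, MTLResourceOptions::StorageModeShared);

    // Skip the first f32: offsets are expressed in bytes, not elements.
    let input = BufferOffset {
        buffer: &input_buffer,
        offset_in_bytes: std::mem::size_of::<f32>(), // assumed field name
    };
    call_unary_contiguous(
        &device,
        command_buffer,
        &kernels,
        unary::contiguous::cos::FLOAT, // assumed kernel constant path
        3, // number of elements to process, starting at the offset
        input,
        &output,
    )
    .expect("unary kernel launch failed");
    command_buffer.commit();
    command_buffer.wait_until_completed();
}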
@@ -463,11 +341,9 @@ pub fn call_unary_strided(
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
strides: &[usize],
offset: usize,
output: &Buffer,
output_offset: usize,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
@@ -476,23 +352,13 @@ pub fn call_unary_strided(
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
set_params!(
encoder,
(
length,
num_dims,
shape,
strides,
(input, offset),
(output, output_offset)
)
);
set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
let width: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@@ -505,8 +371,8 @@ pub fn call_binary_contiguous(
kernels: &Kernels,
kernel_name: binary::contiguous::Kernel,
length: usize,
left: &Buffer,
right: &Buffer,
left: BufferOffset,
right: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@@ -514,12 +380,12 @@ pub fn call_binary_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
set_params!(encoder, (length, &left, &right, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(left, metal::MTLResourceUsage::Read);
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -533,12 +399,10 @@ pub fn call_binary_strided(
kernels: &Kernels,
name: binary::strided::Kernel,
shape: &[usize],
left_input: &Buffer,
left_input: BufferOffset,
left_strides: &[usize],
left_offset: usize,
right_input: &Buffer,
right_input: BufferOffset,
right_strides: &[usize],
right_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@@ -558,16 +422,16 @@ pub fn call_binary_strided(
shape,
left_strides,
right_strides,
(left_input, left_offset),
(right_input, right_offset),
&left_input,
&right_input,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -581,8 +445,7 @@ pub fn call_cast_contiguous(
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
input: &Buffer,
input_offset: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -590,10 +453,10 @@ pub fn call_cast_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output));
set_params!(encoder, (length, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -607,9 +470,8 @@ pub fn call_cast_strided(
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
input_strides: &[usize],
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -621,25 +483,19 @@ pub fn call_cast_strided(
set_params!(
encoder,
(
length,
shape.len(),
shape,
input_strides,
(input, input_offset),
output
)
(length, shape.len(), shape, input_strides, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -687,6 +543,7 @@ pub fn call_reduce_contiguous(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -985,7 +842,7 @@ pub fn call_affine(
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
input: BufferOffset,
output: &Buffer,
mul: f32,
add: f32,
@@ -995,10 +852,10 @@ pub fn call_affine(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output));
set_params!(encoder, (size, mul, add, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1012,9 +869,8 @@ pub fn call_affine_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
mul: f32,
add: f32,
@@ -1034,13 +890,13 @@ pub fn call_affine_strided(
input_stride,
mul,
add,
(input, input_offset),
&input,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1054,7 +910,7 @@ pub fn call_powf(
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
input: BufferOffset,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1063,10 +919,10 @@ pub fn call_powf(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
set_params!(encoder, (size, mul, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1080,9 +936,8 @@ pub fn call_powf_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1094,19 +949,11 @@ pub fn call_powf_strided(
set_params!(
encoder,
(
size,
shape.len(),
shape,
input_stride,
mul,
(input, input_offset),
output
)
(size, shape.len(), shape, input_stride, mul, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1120,7 +967,7 @@ pub fn call_elu(
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
input: BufferOffset,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1129,10 +976,10 @@ pub fn call_elu(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
set_params!(encoder, (size, mul, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1146,9 +993,8 @@ pub fn call_elu_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input: BufferOffset,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1160,25 +1006,18 @@ pub fn call_elu_strided(
set_params!(
encoder,
(
size,
shape.len(),
shape,
input_stride,
mul,
(input, input_offset),
output
)
(size, shape.len(), shape, input_stride, mul, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -1334,6 +1173,7 @@ pub fn call_gather(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -1384,6 +1224,7 @@ pub fn call_scatter_add(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -1910,6 +1751,7 @@ pub enum GgmlDType {
F32,
}
#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_t(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -1925,16 +1767,16 @@ pub fn call_quantized_matmul_t(
let ne00 = k as i64;
let ne01 = n as i64;
let ne02 = b as i64;
let ne03 = 1 as i64;
let ne03 = 1i64;
let nb00 = 0i64;
let nb01 = 0 as i64;
let nb02 = 0 as i64;
let nb01 = 0i64;
let nb02 = 0i64;
let ne10 = k as i64;
let ne11 = m as i64;
let ne12 = b as i64;
let ne13 = 1 as i64;
let ne13 = 1i64;
let nb10 = 0i64;
let nb11 = 0i64;
@@ -2169,6 +2011,7 @@ pub struct CallConvTranspose2dCfg<'a> {
pub kernel_offset: usize,
}
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose2d(
device: &Device,
command_buffer: &CommandBufferRef,