mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module. * Use the BufferOffset for unary ops. * Fix clippy lints. * Use the new BufferOffset. * Adapt the binary ops. * Affine. * More ops (powf, elu, cast).
This commit is contained in:
162
candle-metal-kernels/src/utils.rs
Normal file
162
candle-metal-kernels/src/utils.rs
Normal file
@ -0,0 +1,162 @@
|
||||
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
|
||||
use std::ffi::c_void;
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||
/// actual total buffer length).
|
||||
/// Then kernels can just do their op on their single point in the buffer.
|
||||
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
||||
let size = length as u64;
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
|
||||
let count = (size + width - 1) / width;
|
||||
let thread_group_count = MTLSize {
|
||||
width: count,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
(thread_group_count, thread_group_size)
|
||||
}
|
||||
|
||||
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
|
||||
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
||||
let mut pows0 = 0u64;
|
||||
let mut pows1 = 0u64;
|
||||
let mut pows2 = 0u64;
|
||||
let mut sum = 0u64;
|
||||
loop {
|
||||
let presum = sum;
|
||||
// Check all the pows
|
||||
if dim0 >= (1 << (pows0 + 1)) {
|
||||
pows0 += 1;
|
||||
sum += 1;
|
||||
}
|
||||
if sum == 10 {
|
||||
break;
|
||||
}
|
||||
if dim1 >= (1 << (pows1 + 1)) {
|
||||
pows1 += 1;
|
||||
sum += 1;
|
||||
}
|
||||
if sum == 10 {
|
||||
break;
|
||||
}
|
||||
if dim2 >= (1 << (pows2 + 1)) {
|
||||
pows2 += 1;
|
||||
sum += 1;
|
||||
}
|
||||
if sum == presum || sum == 10 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
MTLSize {
|
||||
width: 1 << pows0,
|
||||
height: 1 << pows1,
|
||||
depth: 1 << pows2,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_param<P: EncoderParam>(
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
position: u64,
|
||||
data: P,
|
||||
) {
|
||||
<P as EncoderParam>::set_param(encoder, position, data)
|
||||
}
|
||||
|
||||
/// Helper functions to create the various objects on the compute command encoder
|
||||
/// on a single line.
|
||||
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
||||
pub(crate) trait EncoderParam {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
||||
}
|
||||
macro_rules! primitive {
|
||||
($type:ty) => {
|
||||
impl EncoderParam for $type {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_bytes(
|
||||
position,
|
||||
core::mem::size_of::<$type>() as u64,
|
||||
&data as *const $type as *const c_void,
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
primitive!(bool);
|
||||
primitive!(usize);
|
||||
primitive!(i32);
|
||||
primitive!(i64);
|
||||
primitive!(u32);
|
||||
primitive!(u64);
|
||||
primitive!(f32);
|
||||
|
||||
pub struct BufferOffset<'a> {
|
||||
pub buffer: &'a Buffer,
|
||||
pub offset_in_bytes: usize,
|
||||
}
|
||||
|
||||
impl<'a> BufferOffset<'a> {
|
||||
pub fn zero_offset(buffer: &'a Buffer) -> Self {
|
||||
Self {
|
||||
buffer,
|
||||
offset_in_bytes: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> EncoderParam for &[T] {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_bytes(
|
||||
position,
|
||||
core::mem::size_of_val(data) as u64,
|
||||
data.as_ptr() as *const c_void,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderParam for &Buffer {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_buffer(position, Some(data), 0);
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderParam for (&Buffer, usize) {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> EncoderParam for &BufferOffset<'a> {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderParam for &mut Buffer {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_buffer(position, Some(data), 0);
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderParam for (&mut Buffer, usize) {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! set_params {
|
||||
($encoder:ident, ($($param:expr),+)) => (
|
||||
let mut _index = 0;
|
||||
$(
|
||||
$crate::utils::set_param($encoder, _index, $param);
|
||||
_index += 1;
|
||||
)*
|
||||
);
|
||||
}
|
Reference in New Issue
Block a user