// candle/candle-metal-kernels/src/utils.rs
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
use std::ffi::c_void;

/// Most kernels apply in the same way across the whole tensor.
/// This creates a dispatch strategy that uses the maximum number of threads per threadgroup
/// (capped at the actual total buffer length), so each kernel invocation only has to handle
/// its single point in the buffer.
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
    let count = (size + width - 1) / width;
    let thread_group_count = MTLSize {
        width: count,
        height: 1,
        depth: 1,
    };
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    (thread_group_count, thread_group_size)
}
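
// Worked example (illustrative numbers, not taken from the file): for a pipeline whose
// max_total_threads_per_threadgroup() is 1024 and a buffer of 3000 elements, this yields
// a threadgroup size of 1024x1x1 and a grid of ceil(3000 / 1024) = 3 threadgroups, so a
// kernel is expected to bounds-check its thread id against the element count.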
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
    let mut pows0 = 0u64;
    let mut pows1 = 0u64;
    let mut pows2 = 0u64;
    let mut sum = 0u64;
    loop {
        let presum = sum;
        // Check all the pows
        if dim0 >= (1 << (pows0 + 1)) {
            pows0 += 1;
            sum += 1;
        }
        if sum == 10 {
            break;
        }
        if dim1 >= (1 << (pows1 + 1)) {
            pows1 += 1;
            sum += 1;
        }
        if sum == 10 {
            break;
        }
        if dim2 >= (1 << (pows2 + 1)) {
            pows2 += 1;
            sum += 1;
        }
        if sum == presum || sum == 10 {
            break;
        }
    }
    MTLSize {
        width: 1 << pows0,
        height: 1 << pows1,
        depth: 1 << pows2,
    }
}
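
// Illustration (added, not part of the original file): `get_block_dims` hands out a budget
// of 2^10 = 1024 threads, doubling one dimension at a time in round-robin order as long as
// the doubled extent still fits in the corresponding dim. The sanity check below sketches
// the expected behaviour.
#[cfg(test)]
mod block_dims_tests {
    use super::get_block_dims;

    #[test]
    fn splits_thread_budget_round_robin() {
        // dim0 = 100 and dim1 = 50 each absorb five doublings (2^5 = 32 fits in both),
        // dim2 = 1 never grows, and the loop stops once 2^(5 + 5 + 0) = 1024 threads
        // have been allocated.
        let dims = get_block_dims(100, 50, 1);
        assert_eq!((dims.width, dims.height, dims.depth), (32, 32, 1));
    }
}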
pub(crate) fn set_param<P: EncoderParam>(
    encoder: &ComputeCommandEncoderRef,
    position: u64,
    data: P,
) {
    <P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper trait used to set the various arguments on the compute command encoder
/// on a single line.
/// Prevents getting an argument's index wrong and mixing up an element count with a size in bytes.
pub(crate) trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
    ($type:ty) => {
        impl EncoderParam for $type {
            fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
                encoder.set_bytes(
                    position,
                    core::mem::size_of::<$type>() as u64,
                    &data as *const $type as *const c_void,
                );
            }
        }
    };
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
pub struct BufferOffset<'a> {
    pub buffer: &'a Buffer,
    pub offset_in_bytes: usize,
}
impl<'a> BufferOffset<'a> {
    pub fn zero_offset(buffer: &'a Buffer) -> Self {
        Self {
            buffer,
            offset_in_bytes: 0,
        }
    }
}
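
// Usage sketch (added for illustration; `encoder`, `length`, `input` and `offset` are
// hypothetical): a non-zero starting offset into a buffer can be bound either through
// `&BufferOffset` or the `(&Buffer, usize)` form of `EncoderParam` implemented below, e.g.
//
//     set_params!(encoder, (length, &BufferOffset { buffer: &input, offset_in_bytes: offset }));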
impl<T> EncoderParam for &[T] {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of_val(data) as u64,
            data.as_ptr() as *const c_void,
        );
    }
}
impl EncoderParam for &Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}
impl EncoderParam for (&Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}
impl<'a> EncoderParam for &BufferOffset<'a> {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
    }
}
impl EncoderParam for &mut Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}
impl EncoderParam for (&mut Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}
#[macro_export]
macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        let mut _index = 0;
        $(
            $crate::utils::set_param($encoder, _index, $param);
            _index += 1;
        )*
    );
}
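
// Usage sketch (added for illustration; `encoder`, `length`, `input` and `output` are
// hypothetical): each argument is bound to consecutive buffer indices starting at 0,
// picking the matching `EncoderParam` impl for scalars, slices and buffers, e.g.
//
//     set_params!(encoder, (length, &input, &output));
//
// A fuller dispatch sketch is given at the end of the file.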
pub trait EncoderProvider {
    type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
    where
        Self: 'a;
    fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
}
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
impl<'a> Drop for WrappedEncoder<'a> {
    fn drop(&mut self) {
        self.0.end_encoding()
    }
}
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
    fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
        &self.0
    }
}
impl EncoderProvider for &metal::CommandBuffer {
    type Encoder<'a> = WrappedEncoder<'a>
    where
        Self: 'a;
    fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
        WrappedEncoder(self.new_compute_command_encoder())
    }
}
impl EncoderProvider for &metal::CommandBufferRef {
    type Encoder<'a> = WrappedEncoder<'a>
    where
        Self: 'a;
    fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
        WrappedEncoder(self.new_compute_command_encoder())
    }
}
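
// A minimal end-to-end dispatch sketch, added for illustration only. The function name,
// argument names and the (length, input, output) argument layout are hypothetical; a real
// launch also needs a compiled pipeline for a kernel with a matching signature.
#[allow(dead_code)]
fn example_dispatch(
    command_buffer: &metal::CommandBufferRef,
    pipeline: &ComputePipelineState,
    input: &Buffer,
    output: &Buffer,
    length: usize,
) {
    // `encoder()` comes from `EncoderProvider`; the returned `WrappedEncoder` calls
    // `end_encoding` when it is dropped at the end of this function.
    let encoder = command_buffer.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(pipeline);
    // Bind the kernel arguments at buffer indices 0, 1 and 2.
    set_params!(encoder, (length, input, output));
    // One thread per element, grouped as computed by `linear_split`.
    let (thread_group_count, thread_group_size) = linear_split(pipeline, length);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
}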