mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
202 lines
5.4 KiB
Rust
202 lines
5.4 KiB
Rust
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
|
|
use std::ffi::c_void;
|
|
|
|
/// Most kernels apply similarly across the tensors
|
|
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
|
/// actual total buffer length).
|
|
/// Then kernels can just do their op on their single point in the buffer.
|
|
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
|
let size = length as u64;
|
|
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
|
|
let count = (size + width - 1) / width;
|
|
let thread_group_count = MTLSize {
|
|
width: count,
|
|
height: 1,
|
|
depth: 1,
|
|
};
|
|
|
|
let thread_group_size = MTLSize {
|
|
width,
|
|
height: 1,
|
|
depth: 1,
|
|
};
|
|
(thread_group_count, thread_group_size)
|
|
}
|
|
|
|
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
|
|
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
|
let mut pows0 = 0u64;
|
|
let mut pows1 = 0u64;
|
|
let mut pows2 = 0u64;
|
|
let mut sum = 0u64;
|
|
loop {
|
|
let presum = sum;
|
|
// Check all the pows
|
|
if dim0 >= (1 << (pows0 + 1)) {
|
|
pows0 += 1;
|
|
sum += 1;
|
|
}
|
|
if sum == 10 {
|
|
break;
|
|
}
|
|
if dim1 >= (1 << (pows1 + 1)) {
|
|
pows1 += 1;
|
|
sum += 1;
|
|
}
|
|
if sum == 10 {
|
|
break;
|
|
}
|
|
if dim2 >= (1 << (pows2 + 1)) {
|
|
pows2 += 1;
|
|
sum += 1;
|
|
}
|
|
if sum == presum || sum == 10 {
|
|
break;
|
|
}
|
|
}
|
|
MTLSize {
|
|
width: 1 << pows0,
|
|
height: 1 << pows1,
|
|
depth: 1 << pows2,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn set_param<P: EncoderParam>(
|
|
encoder: &ComputeCommandEncoderRef,
|
|
position: u64,
|
|
data: P,
|
|
) {
|
|
<P as EncoderParam>::set_param(encoder, position, data)
|
|
}
|
|
|
|
/// Helper functions to create the various objects on the compute command encoder
|
|
/// on a single line.
|
|
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
|
pub(crate) trait EncoderParam {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
|
}
|
|
macro_rules! primitive {
|
|
($type:ty) => {
|
|
impl EncoderParam for $type {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_bytes(
|
|
position,
|
|
core::mem::size_of::<$type>() as u64,
|
|
&data as *const $type as *const c_void,
|
|
);
|
|
}
|
|
}
|
|
};
|
|
}
|
|
primitive!(bool);
|
|
primitive!(usize);
|
|
primitive!(i32);
|
|
primitive!(i64);
|
|
primitive!(u32);
|
|
primitive!(u64);
|
|
primitive!(f32);
|
|
|
|
pub struct BufferOffset<'a> {
|
|
pub buffer: &'a Buffer,
|
|
pub offset_in_bytes: usize,
|
|
}
|
|
|
|
impl<'a> BufferOffset<'a> {
|
|
pub fn zero_offset(buffer: &'a Buffer) -> Self {
|
|
Self {
|
|
buffer,
|
|
offset_in_bytes: 0,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> EncoderParam for &[T] {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_bytes(
|
|
position,
|
|
core::mem::size_of_val(data) as u64,
|
|
data.as_ptr() as *const c_void,
|
|
);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for &Buffer {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data), 0);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for (&Buffer, usize) {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
|
}
|
|
}
|
|
|
|
impl<'a> EncoderParam for &BufferOffset<'a> {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for &mut Buffer {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data), 0);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for (&mut Buffer, usize) {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
|
}
|
|
}
|
|
|
|
/// Binds a parenthesised list of parameters on an encoder at consecutive argument indices.
///
/// The first parameter is bound at index `0`, the next at `1`, and so on; each one goes through
/// `utils::set_param`. The macro expands to a sequence of statements, so it has to be invoked in
/// statement position: `set_params!(encoder, (a, b, c));`.
#[macro_export]
macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => {
        // Running argument index, advanced after every bound parameter.
        let mut _index = 0;
        $(
            $crate::utils::set_param($encoder, _index, $param);
            _index += 1;
        )+
    };
}
|
|
|
|
pub trait EncoderProvider {
|
|
type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
|
|
where
|
|
Self: 'a;
|
|
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
|
|
}
|
|
|
|
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
|
|
|
|
impl<'a> Drop for WrappedEncoder<'a> {
|
|
fn drop(&mut self) {
|
|
self.0.end_encoding()
|
|
}
|
|
}
|
|
|
|
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
|
|
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl EncoderProvider for &metal::CommandBuffer {
|
|
type Encoder<'a> = WrappedEncoder<'a>
|
|
where
|
|
Self: 'a;
|
|
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
|
|
WrappedEncoder(self.new_compute_command_encoder())
|
|
}
|
|
}
|
|
|
|
impl EncoderProvider for &metal::CommandBufferRef {
|
|
type Encoder<'a> = WrappedEncoder<'a>
|
|
where
|
|
Self: 'a;
|
|
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
|
|
WrappedEncoder(self.new_compute_command_encoder())
|
|
}
|
|
}
|