mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00

* Add a scalar enum type. * Add a bit more to the scalar type. * Small tweak. * More scalar usage.
228 lines
6.0 KiB
Rust
228 lines
6.0 KiB
Rust
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
|
|
use std::ffi::c_void;
|
|
|
|
/// Most kernels apply similarly across the tensors
|
|
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
|
/// actual total buffer length).
|
|
/// Then kernels can just do their op on their single point in the buffer.
|
|
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
|
let size = length as u64;
|
|
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
|
|
let count = size.div_ceil(width);
|
|
let thread_group_count = MTLSize {
|
|
width: count,
|
|
height: 1,
|
|
depth: 1,
|
|
};
|
|
|
|
let thread_group_size = MTLSize {
|
|
width,
|
|
height: 1,
|
|
depth: 1,
|
|
};
|
|
(thread_group_count, thread_group_size)
|
|
}
|
|
|
|
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
|
|
pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
|
let mut pows0 = 0u64;
|
|
let mut pows1 = 0u64;
|
|
let mut pows2 = 0u64;
|
|
let mut sum = 0u64;
|
|
loop {
|
|
let presum = sum;
|
|
// Check all the pows
|
|
if dim0 >= (1 << (pows0 + 1)) {
|
|
pows0 += 1;
|
|
sum += 1;
|
|
}
|
|
if sum == 10 {
|
|
break;
|
|
}
|
|
if dim1 >= (1 << (pows1 + 1)) {
|
|
pows1 += 1;
|
|
sum += 1;
|
|
}
|
|
if sum == 10 {
|
|
break;
|
|
}
|
|
if dim2 >= (1 << (pows2 + 1)) {
|
|
pows2 += 1;
|
|
sum += 1;
|
|
}
|
|
if sum == presum || sum == 10 {
|
|
break;
|
|
}
|
|
}
|
|
MTLSize {
|
|
width: 1 << pows0,
|
|
height: 1 << pows1,
|
|
depth: 1 << pows2,
|
|
}
|
|
}
|
|
|
|
/// Binds `data` to argument slot `position` of `encoder` by delegating to the
/// [`EncoderParam`] implementation of `P`.
pub fn set_param<P>(encoder: &ComputeCommandEncoderRef, position: u64, data: P)
where
    P: EncoderParam,
{
    P::set_param(encoder, position, data)
}
|
|
|
|
/// A value that can be bound to an argument slot of a compute command encoder.
///
/// Implementing this trait lets every kind of kernel argument be set with a single,
/// uniform call. This avoids passing the wrong number of arguments or confusing an element
/// count with a size in bytes.
pub trait EncoderParam {
    /// Binds `data` to argument slot `position` of `encoder`.
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
|
|
macro_rules! primitive {
|
|
($type:ty) => {
|
|
impl EncoderParam for $type {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_bytes(
|
|
position,
|
|
core::mem::size_of::<$type>() as u64,
|
|
&data as *const $type as *const c_void,
|
|
);
|
|
}
|
|
}
|
|
};
|
|
}
|
|
primitive!(bool);
|
|
primitive!(usize);
|
|
primitive!(i32);
|
|
primitive!(i64);
|
|
primitive!(u8);
|
|
primitive!(u32);
|
|
primitive!(u64);
|
|
primitive!(f32);
|
|
primitive!(f64);
|
|
primitive!(half::bf16);
|
|
primitive!(half::f16);
|
|
|
|
/// A borrowed Metal buffer together with a byte offset into it.
pub struct BufferOffset<'a> {
    /// The underlying buffer.
    pub buffer: &'a Buffer,
    /// Offset from the start of `buffer`, in bytes.
    pub offset_in_bytes: usize,
}
|
|
|
|
impl<'a> BufferOffset<'a> {
|
|
pub fn zero_offset(buffer: &'a Buffer) -> Self {
|
|
Self {
|
|
buffer,
|
|
offset_in_bytes: 0,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> EncoderParam for &[T] {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_bytes(
|
|
position,
|
|
core::mem::size_of_val(data) as u64,
|
|
data.as_ptr() as *const c_void,
|
|
);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for &Buffer {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data), 0);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for (&Buffer, usize) {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for &BufferOffset<'_> {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for &mut Buffer {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data), 0);
|
|
}
|
|
}
|
|
|
|
impl EncoderParam for (&mut Buffer, usize) {
|
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
|
}
|
|
}
|
|
|
|
/// Binds a parenthesised, comma-separated list of parameters to consecutive argument
/// positions of `$encoder`, starting at position 0, by calling `set_param` once per entry.
#[macro_export]
macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        // Leading underscore: the increment after the last parameter is never read.
        let mut _position = 0;
        $(
            $crate::utils::set_param($encoder, _position, $param);
            _position += 1;
        )+
    );
}
|
|
|
|
/// A source of compute command encoders.
///
/// Implementors hand out an encoder handle that exposes a `ComputeCommandEncoderRef`
/// through `AsRef`.
pub trait EncoderProvider {
    /// Encoder handle produced by this provider; it may borrow from the provider for `'a`.
    type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
    where
        Self: 'a;
    /// Returns an encoder handle tied to the borrow of `self`.
    fn encoder(&self) -> Self::Encoder<'_>;
}
|
|
|
|
/// A borrowed compute command encoder paired with a flag that controls whether encoding is
/// ended when the wrapper is dropped.
pub struct WrappedEncoder<'a> {
    /// The wrapped encoder.
    inner: &'a ComputeCommandEncoderRef,
    /// When `true`, the `Drop` implementation calls `end_encoding()` on `inner`.
    end_encoding_on_drop: bool,
}
|
|
|
|
impl Drop for WrappedEncoder<'_> {
|
|
fn drop(&mut self) {
|
|
if self.end_encoding_on_drop {
|
|
self.inner.end_encoding()
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'_> {
|
|
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
|
|
self.inner
|
|
}
|
|
}
|
|
|
|
impl EncoderProvider for &metal::CommandBuffer {
|
|
type Encoder<'a>
|
|
= WrappedEncoder<'a>
|
|
where
|
|
Self: 'a;
|
|
fn encoder(&self) -> Self::Encoder<'_> {
|
|
WrappedEncoder {
|
|
inner: self.new_compute_command_encoder(),
|
|
end_encoding_on_drop: true,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl EncoderProvider for &metal::CommandBufferRef {
|
|
type Encoder<'a>
|
|
= WrappedEncoder<'a>
|
|
where
|
|
Self: 'a;
|
|
fn encoder(&self) -> Self::Encoder<'_> {
|
|
WrappedEncoder {
|
|
inner: self.new_compute_command_encoder(),
|
|
end_encoding_on_drop: true,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl EncoderProvider for &ComputeCommandEncoderRef {
|
|
type Encoder<'a>
|
|
= WrappedEncoder<'a>
|
|
where
|
|
Self: 'a;
|
|
fn encoder(&self) -> Self::Encoder<'_> {
|
|
WrappedEncoder {
|
|
inner: self,
|
|
end_encoding_on_drop: false,
|
|
}
|
|
}
|
|
}
|