Add fill kernel handler

This commit is contained in:
Ivar Flakstad
2023-12-29 12:27:12 +01:00
parent fd9bf3bcdd
commit 0a29d2e9b8
3 changed files with 83 additions and 3 deletions

View File

@ -1,3 +1,4 @@
use half::{bf16, f16};
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const FILL: &str = include_str!("fill.metal");
const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
@ -45,7 +47,7 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
trait EncoderParam {
pub trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
@ -62,7 +64,11 @@ macro_rules! primitive {
};
}
primitive!(usize);
primitive!(u8);
primitive!(u32);
primitive!(i64);
primitive!(f16);
primitive!(bf16);
primitive!(f32);
impl<T> EncoderParam for &[T] {
@ -117,6 +123,7 @@ pub enum Source {
Reduce,
Mfa,
Conv,
Fill,
}
macro_rules! ops{
@ -227,6 +234,7 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Fill => FILL,
Source::Conv => CONV,
Source::Mfa => panic!("Invalid lib"),
}
@ -1562,9 +1570,36 @@ pub fn call_upsample_nearest_2d(
Ok(())
}
#[inline]
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
pub fn call_fill<D: EncoderParam>(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
elem_count: usize,
buffer: &Buffer,
value: D,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger);
set_params!(encoder, (buffer, value, elem_count));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests;