mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add fill kernel handler
This commit is contained in:
@ -14,7 +14,7 @@ metal = { version = "0.27.0", features = ["mps"]}
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
|
||||
[dev-dependencies]
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
rand = "0.8.5"
|
||||
|
@ -1,3 +1,4 @@
|
||||
use half::{bf16, f16};
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const FILL: &str = include_str!("fill.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
@ -45,7 +47,7 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
|
||||
/// Helper functions to create the various objects on the compute command encoder
|
||||
/// on a single line.
|
||||
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
||||
trait EncoderParam {
|
||||
pub trait EncoderParam {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
||||
}
|
||||
macro_rules! primitive {
|
||||
@ -62,7 +64,11 @@ macro_rules! primitive {
|
||||
};
|
||||
}
|
||||
primitive!(usize);
|
||||
primitive!(u8);
|
||||
primitive!(u32);
|
||||
primitive!(i64);
|
||||
primitive!(f16);
|
||||
primitive!(bf16);
|
||||
primitive!(f32);
|
||||
|
||||
impl<T> EncoderParam for &[T] {
|
||||
@ -117,6 +123,7 @@ pub enum Source {
|
||||
Reduce,
|
||||
Mfa,
|
||||
Conv,
|
||||
Fill,
|
||||
}
|
||||
|
||||
macro_rules! ops{
|
||||
@ -227,6 +234,7 @@ impl Kernels {
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Fill => FILL,
|
||||
Source::Conv => CONV,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
@ -1562,9 +1570,36 @@ pub fn call_upsample_nearest_2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
((m + b - 1) / b) as NSUInteger
|
||||
}
|
||||
|
||||
pub fn call_fill<D: EncoderParam>(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
elem_count: usize,
|
||||
buffer: &Buffer,
|
||||
value: D,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Fill, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger);
|
||||
|
||||
set_params!(encoder, (buffer, value, elem_count));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
|
||||
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let ptr = data.as_ptr() as *const core::ffi::c_void;
|
||||
let ptr = data.as_ptr() as *const c_void;
|
||||
let size = (data.len() * std::mem::size_of::<T>()) as u64;
|
||||
device.new_buffer_with_data(ptr, size, options)
|
||||
}
|
||||
@ -806,3 +806,48 @@ fn gemm() {
|
||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
||||
);
|
||||
}
|
||||
|
||||
fn run_fill<T: EncoderParam + Clone>(
|
||||
elem_count: usize,
|
||||
value: T,
|
||||
kernel_name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let buffer = new_buffer(&device, &vec![0.0f32; elem_count]);
|
||||
call_fill(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
kernel_name,
|
||||
elem_count,
|
||||
&buffer,
|
||||
value,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&buffer, elem_count)
|
||||
}
|
||||
|
||||
/// Checks that every fill kernel writes `value` into each element of the
/// buffer, for several buffer sizes and every supported element type.
#[test]
fn fill() {
    /// Runs the kernel `name` for buffer sizes 8^0 through 8^3 and compares
    /// the result with a vector holding `value` in every position.
    fn assert_fill<T: EncoderParam + Copy + std::fmt::Debug + PartialEq>(
        value: T,
        name: &'static str,
    ) {
        for exponent in 0..4u32 {
            // `8 ^ i` is bitwise XOR in Rust, not exponentiation; use `pow`
            // so the sizes are the powers of eight: 1, 8, 64, 512.
            let elem_count = 8usize.pow(exponent);
            assert_eq!(run_fill(elem_count, value, name), vec![value; elem_count]);
        }
    }
    assert_fill(123u8, "fill_u8");
    assert_fill(456u32, "fill_u32");
    assert_fill(789i64, "fill_i64");
    assert_fill(f16::from_f32(1.23), "fill_f16");
    assert_fill(bf16::from_f32(4.56), "fill_bf16");
    assert_fill(7.89f32, "fill_f32");
}
|
||||
|
Reference in New Issue
Block a user