Add fill kernel handler

This commit is contained in:
Ivar Flakstad
2023-12-29 12:27:12 +01:00
parent fd9bf3bcdd
commit 0a29d2e9b8
3 changed files with 83 additions and 3 deletions

View File

@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const core::ffi::c_void;
let ptr = data.as_ptr() as *const c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options)
}
@ -806,3 +806,48 @@ fn gemm() {
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
}
fn run_fill<T: EncoderParam + Clone>(
elem_count: usize,
value: T,
kernel_name: &'static str,
) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let buffer = new_buffer(&device, &vec![0.0f32; elem_count]);
call_fill(
&device,
command_buffer,
&kernels,
kernel_name,
elem_count,
&buffer,
value,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&buffer, elem_count)
}
#[test]
fn fill() {
fn assert_fill<T: EncoderParam + Copy + std::fmt::Debug + PartialEq>(
value: T,
name: &'static str,
) {
for i in 0..4 {
assert_eq!(run_fill(8 ^ i, value, name), vec![value; 8 ^ i]);
}
}
assert_fill(123u8, "fill_u8");
assert_fill(456u32, "fill_u32");
assert_fill(789i64, "fill_i64");
assert_fill(f16::from_f32(1.23), "fill_f16");
assert_fill(bf16::from_f32(4.56), "fill_bf16");
assert_fill(7.89f32, "fill_f32");
}