Adding some half kernels.

This commit is contained in:
Nicolas Patry
2023-11-11 13:30:21 +01:00
parent e02f1912bb
commit 54355ff997
4 changed files with 165 additions and 12 deletions

View File

@ -460,6 +460,46 @@ pub fn call_cast_contiguous(
Ok(())
}
pub fn call_cast_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
input: &Buffer,
input_strides: &[usize],
input_offset: usize,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
set_params!(
encoder,
(length, shape, input_strides, (input, input_offset), output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@ -565,13 +605,14 @@ pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
output: &mut Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, "affine_float")?;
let func = kernels.load_function(device, Source::Affine, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@ -596,6 +637,7 @@ pub fn call_affine_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input_stride: &[usize],
@ -604,7 +646,7 @@ pub fn call_affine_strided(
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, "affine_float_strided")?;
let func = kernels.load_function(device, Source::Affine, name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));