Metal operational.

This commit is contained in:
Nicolas Patry
2023-11-18 00:52:38 +01:00
parent a0010898cc
commit 251c65f9f1
6 changed files with 69 additions and 27 deletions

View File

@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -484,6 +481,7 @@ pub fn call_reduce_contiguous(
length: usize,
out_length: usize,
input: &Buffer,
input_offset: usize,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@ -492,7 +490,7 @@ pub fn call_reduce_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, input, output));
set_params!(encoder, (length, elements_to_sum, (input,input_offset), output));
let thread_group_count = MTLSize {
width: out_length as u64,
@ -1228,6 +1226,7 @@ mod tests {
v.len(),
out_length,
&input,
0,
&mut output,
)
.unwrap();