Addressing a lot of comments.

This commit is contained in:
Nicolas Patry
2023-12-15 13:06:04 +01:00
parent aa04015098
commit 6bc92e63cb
4 changed files with 33 additions and 20 deletions

View File

@ -597,6 +597,7 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@ -604,7 +605,10 @@ pub fn call_last_softmax(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, input, output));
set_params!(
encoder,
(length, elements_to_sum, (input, input_offset), output)
);
let out_length = length / elements_to_sum;