mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Addressing a lot of comments.
This commit is contained in:
@ -597,6 +597,7 @@ pub fn call_last_softmax(
|
||||
length: usize,
|
||||
elements_to_sum: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
@ -604,7 +605,10 @@ pub fn call_last_softmax(
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, elements_to_sum, input, output));
|
||||
set_params!(
|
||||
encoder,
|
||||
(length, elements_to_sum, (input, input_offset), output)
|
||||
);
|
||||
|
||||
let out_length = length / elements_to_sum;
|
||||
|
||||
|
Reference in New Issue
Block a user