Add a few metal gather ops. (#2740)

* Add a few metal gather ops.

* Fix some compilation issues.

* Adjust the tolerance.
This commit is contained in:
Laurent Mazare
2025-01-25 23:31:03 +01:00
committed by GitHub
parent 333d94a19a
commit 1a32107fab
5 changed files with 17 additions and 5 deletions

View File

@ -2029,7 +2029,7 @@ pub fn call_sdpa_vector_2pass(
)]));
let pipeline =
kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@ -2104,7 +2104,7 @@ pub fn call_sdpa_vector_2pass(
let b = (q_shape[0] * q_shape[1]) as i32;
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);