mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add a few metal gather ops. (#2740)
* Add a few metal gather ops. * Fix some compilation issues. * Adjust the tolerance.
This commit is contained in:
@ -2029,7 +2029,7 @@ pub fn call_sdpa_vector_2pass(
|
||||
)]));
|
||||
|
||||
let pipeline =
|
||||
kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
|
||||
kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -2104,7 +2104,7 @@ pub fn call_sdpa_vector_2pass(
|
||||
|
||||
let b = (q_shape[0] * q_shape[1]) as i32;
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
Reference in New Issue
Block a user