Add a few metal gather ops. (#2740)

* Add a few metal gather ops.

* Fix some compilation issues.

* Adjust the tolerance.
This commit is contained in:
Laurent Mazare
2025-01-25 23:31:03 +01:00
committed by GitHub
parent 333d94a19a
commit 1a32107fab
5 changed files with 17 additions and 5 deletions

View File

@ -1245,6 +1245,12 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16",
(DType::U32, DType::U32) => "gather_u32_u32",
(DType::U32, DType::I64) => "gather_u32_i64",
(DType::I64, DType::F32) => "gather_i64_f32",
(DType::I64, DType::F16) => "gather_i64_f16",
(DType::I64, DType::BF16) => "gather_i64_bf16",
(DType::I64, DType::U32) => "gather_i64_u32",
(DType::I64, DType::I64) => "gather_i64_i64",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;

View File

@ -209,12 +209,18 @@ INDEX_OP(is_u8_f16, uint8_t, half)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
#endif
GATHER_OP(gather_i64_f32, int64_t, float)
GATHER_OP(gather_i64_f16, int64_t, half)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_i64_bf16, int64_t, bfloat)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif
GATHER_OP(gather_i64_u32, int64_t, uint)
GATHER_OP(gather_u32_u32, uint, uint)
GATHER_OP(gather_i64_i64, int64_t, int64_t)
GATHER_OP(gather_u32_i64, uint, int64_t)
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)

View File

@ -2029,7 +2029,7 @@ pub fn call_sdpa_vector_2pass(
)]));
let pipeline =
kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@ -2104,7 +2104,7 @@ pub fn call_sdpa_vector_2pass(
let b = (q_shape[0] * q_shape[1]) as i32;
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);

View File

@ -1404,7 +1404,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
const constant size_t& v_stride, \
const constant float& scale, \
const constant float& softcapping, \
const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \
const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
uint3 tid [[threadgroup_position_in_grid]], \
@ -1424,7 +1424,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
const constant size_t& v_stride, \
const constant float& scale, \
const constant float& softcapping, \
const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \
const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
uint3 tid [[threadgroup_position_in_grid]], \

View File

@ -116,7 +116,7 @@ mod metal_sdpa_tests {
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0004, "{}", error);
assert!(error <= 0.0005, "{}", error);
Ok(())
}