mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a few metal gather ops. (#2740)
* Add a few metal gather ops.
* Fix some compilation issues.
* Adjust the tolerance.
This commit is contained in:
@ -1245,6 +1245,12 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U32, DType::F16) => "gather_u32_f16",
|
||||
(DType::U32, DType::BF16) => "gather_u32_bf16",
|
||||
(DType::U32, DType::U32) => "gather_u32_u32",
|
||||
(DType::U32, DType::I64) => "gather_u32_i64",
|
||||
(DType::I64, DType::F32) => "gather_i64_f32",
|
||||
(DType::I64, DType::F16) => "gather_i64_f16",
|
||||
(DType::I64, DType::BF16) => "gather_i64_bf16",
|
||||
(DType::I64, DType::U32) => "gather_i64_u32",
|
||||
(DType::I64, DType::I64) => "gather_i64_i64",
|
||||
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
|
@ -209,12 +209,18 @@ INDEX_OP(is_u8_f16, uint8_t, half)
|
||||
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
||||
#endif
|
||||
|
||||
GATHER_OP(gather_i64_f32, int64_t, float)
|
||||
GATHER_OP(gather_i64_f16, int64_t, half)
|
||||
GATHER_OP(gather_u32_f32, uint, float)
|
||||
GATHER_OP(gather_u32_f16, uint, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
GATHER_OP(gather_i64_bf16, int64_t, bfloat)
|
||||
GATHER_OP(gather_u32_bf16, uint, bfloat)
|
||||
#endif
|
||||
GATHER_OP(gather_i64_u32, int64_t, uint)
|
||||
GATHER_OP(gather_u32_u32, uint, uint)
|
||||
GATHER_OP(gather_i64_i64, int64_t, int64_t)
|
||||
GATHER_OP(gather_u32_i64, uint, int64_t)
|
||||
|
||||
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
|
||||
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
|
||||
|
@ -2029,7 +2029,7 @@ pub fn call_sdpa_vector_2pass(
|
||||
)]));
|
||||
|
||||
let pipeline =
|
||||
kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
|
||||
kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -2104,7 +2104,7 @@ pub fn call_sdpa_vector_2pass(
|
||||
|
||||
let b = (q_shape[0] * q_shape[1]) as i32;
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
@ -1404,7 +1404,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
|
||||
const constant size_t& v_stride, \
|
||||
const constant float& scale, \
|
||||
const constant float& softcapping, \
|
||||
const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \
|
||||
const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
|
||||
const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
|
||||
const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
@ -1424,7 +1424,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
|
||||
const constant size_t& v_stride, \
|
||||
const constant float& scale, \
|
||||
const constant float& softcapping, \
|
||||
const device bool* mask [[function_constant(sdpa_vector_has_mask)]],, \
|
||||
const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
|
||||
const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
|
||||
const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
|
@ -116,7 +116,7 @@ mod metal_sdpa_tests {
|
||||
.sum_all()?
|
||||
.to_scalar()?;
|
||||
|
||||
assert!(error <= 0.0004, "{}", error);
|
||||
assert!(error <= 0.0005, "{}", error);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user