mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add support for index u8/i64 and input f16/bf16 scatter-add on metal (#1849)
* add support and tests for scatter add on metal
* add support for all datatypes
This commit is contained in:
@ -1128,7 +1128,15 @@ impl BackendStorage for MetalStorage {
|
|||||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||||
};
|
};
|
||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
|
(DType::U8, DType::F32) => "sa_u8_f32",
|
||||||
|
(DType::U8, DType::F16) => "sa_u8_f16",
|
||||||
|
(DType::U8, DType::BF16) => "sa_u8_bf16",
|
||||||
(DType::U32, DType::F32) => "sa_u32_f32",
|
(DType::U32, DType::F32) => "sa_u32_f32",
|
||||||
|
(DType::U32, DType::F16) => "sa_u32_f16",
|
||||||
|
(DType::U32, DType::BF16) => "sa_u32_bf16",
|
||||||
|
(DType::I64, DType::F32) => "sa_i64_f32",
|
||||||
|
(DType::I64, DType::F16) => "sa_i64_f16",
|
||||||
|
(DType::I64, DType::BF16) => "sa_i64_bf16",
|
||||||
_ => Err(MetalError::UnexpectedDType {
|
_ => Err(MetalError::UnexpectedDType {
|
||||||
msg: "scatter-add ids should be u8/u32/i64",
|
msg: "scatter-add ids should be u8/u32/i64",
|
||||||
expected: DType::U32,
|
expected: DType::U32,
|
||||||
|
@ -167,11 +167,16 @@ kernel void NAME( \
|
|||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint, float)
|
INDEX_OP(is_u32_f32, uint, float)
|
||||||
INDEX_OP(is_u32_f16, uint, half)
|
INDEX_OP(is_u32_f16, uint, half)
|
||||||
|
|
||||||
GATHER_OP(gather_u32_f32, uint, float)
|
GATHER_OP(gather_u32_f32, uint, float)
|
||||||
GATHER_OP(gather_u32_f16, uint, half)
|
GATHER_OP(gather_u32_f16, uint, half)
|
||||||
SCATTER_ADD_OP(sa_u32_f32, uint, float)
|
|
||||||
SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
|
||||||
|
|
||||||
|
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
|
||||||
|
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
|
||||||
|
SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
|
||||||
|
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
|
||||||
|
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
|
||||||
|
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
|
||||||
|
|
||||||
#if defined(__HAVE_BFLOAT__)
|
#if defined(__HAVE_BFLOAT__)
|
||||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||||
@ -180,6 +185,10 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
|||||||
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||||
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||||
|
|
||||||
|
SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
|
||||||
|
SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
|
||||||
|
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
|
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
|
||||||
|
@ -1066,3 +1066,107 @@ fn random() {
|
|||||||
validate_random!(f16);
|
validate_random!(f16);
|
||||||
validate_random!(bf16);
|
validate_random!(bf16);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||||
|
input: &[T],
|
||||||
|
ids: &[I],
|
||||||
|
shape: &[usize],
|
||||||
|
dim: usize,
|
||||||
|
name: &'static str,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
let input_buffer = new_buffer(&device, input);
|
||||||
|
let ids_buffer = new_buffer(&device, ids);
|
||||||
|
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
|
||||||
|
call_scatter_add(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
shape,
|
||||||
|
shape,
|
||||||
|
dim,
|
||||||
|
&input_buffer,
|
||||||
|
0,
|
||||||
|
&ids_buffer,
|
||||||
|
0,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
read_to_vec(&output, input.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs every scatter-add kernel variant (u8/u32/i64 ids crossed with
/// f32/f16/bf16 values) on a 1-D shape `[8]` and a 2-D shape `[4, 2]`,
/// always with `dim = 0`, and compares the result with hard-coded expectations.
#[test]
fn scatter_add() {
    // Element-wise conversions from the f32 reference data.
    let as_f16 = |vals: &[f32]| -> Vec<f16> { vals.iter().map(|v| f16::from_f32(*v)).collect() };
    let as_bf16 =
        |vals: &[f32]| -> Vec<bf16> { vals.iter().map(|v| bf16::from_f32(*v)).collect() };

    // The same index pattern in each supported index dtype.
    let ids_u8 = [0u8, 0, 1, 0, 2, 2, 3, 3];
    let ids_u32 = [0u32, 0, 1, 0, 2, 2, 3, 3];
    let ids_i64 = [0i64, 0, 1, 0, 2, 2, 3, 3];

    // The same input values in each supported value dtype.
    let input_f32 = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0];
    let input_f16 = as_f16(&input_f32);
    let input_bf16 = as_bf16(&input_f32);

    // Expected output for the 1-D shape `[8]`.
    let expected_1d_f32: Vec<f32> = vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0];
    let expected_1d_f16 = as_f16(&expected_1d_f32);
    let expected_1d_bf16 = as_bf16(&expected_1d_f32);

    // Expected output for the 2-D shape `[4, 2]`.
    let expected_2d_f32: Vec<f32> = vec![5.0, 3.0, 7.0, 0.0, 3.0, 2.0, 1.0, 3.0];
    let expected_2d_f16 = as_f16(&expected_2d_f32);
    let expected_2d_bf16 = as_bf16(&expected_2d_f32);

    let cases = [
        (
            vec![8],
            expected_1d_f32,
            expected_1d_f16,
            expected_1d_bf16,
        ),
        (
            vec![4, 2],
            expected_2d_f32,
            expected_2d_f16,
            expected_2d_bf16,
        ),
    ];

    for (shape, want_f32, want_f16, want_bf16) in cases {
        // For each value dtype, all three index dtypes are run before any comparison.
        let f32_runs = [
            run_scatter_add(&input_f32, &ids_u8, &shape, 0, "sa_u8_f32"),
            run_scatter_add(&input_f32, &ids_u32, &shape, 0, "sa_u32_f32"),
            run_scatter_add(&input_f32, &ids_i64, &shape, 0, "sa_i64_f32"),
        ];
        for got in f32_runs {
            assert_eq!(got, want_f32);
        }

        let f16_runs = [
            run_scatter_add(&input_f16, &ids_u8, &shape, 0, "sa_u8_f16"),
            run_scatter_add(&input_f16, &ids_u32, &shape, 0, "sa_u32_f16"),
            run_scatter_add(&input_f16, &ids_i64, &shape, 0, "sa_i64_f16"),
        ];
        for got in f16_runs {
            assert_eq!(got, want_f16);
        }

        let bf16_runs = [
            run_scatter_add(&input_bf16, &ids_u8, &shape, 0, "sa_u8_bf16"),
            run_scatter_add(&input_bf16, &ids_u32, &shape, 0, "sa_u32_bf16"),
            run_scatter_add(&input_bf16, &ids_i64, &shape, 0, "sa_i64_bf16"),
        ];
        for got in bf16_runs {
            assert_eq!(got, want_bf16);
        }
    }
}
|
||||||
|
Reference in New Issue
Block a user