mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
add test for index add and add missing match statements (#1862)
This commit is contained in:
@ -1242,9 +1242,29 @@ impl BackendStorage for MetalStorage {
|
|||||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||||
};
|
};
|
||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
|
(DType::I64, DType::BF16) => "ia_i64_bf16",
|
||||||
|
(DType::I64, DType::F16) => "ia_i64_f16",
|
||||||
|
(DType::I64, DType::F32) => "ia_i64_f32",
|
||||||
|
(DType::I64, DType::I64) => "ia_i64_i64",
|
||||||
|
(DType::I64, DType::U32) => "ia_i64_u32",
|
||||||
|
(DType::I64, DType::U8) => "ia_i64_u8",
|
||||||
|
|
||||||
|
(DType::U32, DType::BF16) => "ia_u32_bf16",
|
||||||
|
(DType::U32, DType::F16) => "ia_u32_f16",
|
||||||
(DType::U32, DType::F32) => "ia_u32_f32",
|
(DType::U32, DType::F32) => "ia_u32_f32",
|
||||||
|
(DType::U32, DType::I64) => "ia_u32_i64",
|
||||||
|
(DType::U32, DType::U32) => "ia_u32_u32",
|
||||||
|
(DType::U32, DType::U8) => "ia_u32_u8",
|
||||||
|
|
||||||
|
(DType::U8, DType::BF16) => "ia_u8_bf16",
|
||||||
|
(DType::U8, DType::F16) => "ia_u8_f16",
|
||||||
|
(DType::U8, DType::F32) => "ia_u8_f32",
|
||||||
|
(DType::U8, DType::I64) => "ia_u8_i64",
|
||||||
|
(DType::U8, DType::U32) => "ia_u8_u32",
|
||||||
|
(DType::U8, DType::U8) => "ia_u8_u8",
|
||||||
|
|
||||||
_ => Err(MetalError::UnexpectedDType {
|
_ => Err(MetalError::UnexpectedDType {
|
||||||
msg: "index-add ids should be u32",
|
msg: "index-add ids should be u8/u32/i64",
|
||||||
expected: DType::U32,
|
expected: DType::U32,
|
||||||
got: ids.dtype(),
|
got: ids.dtype(),
|
||||||
})?,
|
})?,
|
||||||
|
@ -167,6 +167,10 @@ kernel void NAME( \
|
|||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint, float)
|
INDEX_OP(is_u32_f32, uint, float)
|
||||||
INDEX_OP(is_u32_f16, uint, half)
|
INDEX_OP(is_u32_f16, uint, half)
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||||
|
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
||||||
|
#endif
|
||||||
|
|
||||||
GATHER_OP(gather_u32_f32, uint, float)
|
GATHER_OP(gather_u32_f32, uint, float)
|
||||||
GATHER_OP(gather_u32_f16, uint, half)
|
GATHER_OP(gather_u32_f16, uint, half)
|
||||||
@ -177,34 +181,38 @@ SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
|
|||||||
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
|
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
|
||||||
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
|
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
|
||||||
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
|
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
|
||||||
|
|
||||||
#if defined(__HAVE_BFLOAT__)
|
#if defined(__HAVE_BFLOAT__)
|
||||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
|
||||||
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
|
||||||
|
|
||||||
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
|
||||||
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
|
||||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
|
||||||
|
|
||||||
SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
|
SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
|
||||||
SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
|
SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
|
||||||
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
|
// i64
|
||||||
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
|
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
|
||||||
|
|
||||||
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
|
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
|
||||||
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
|
|
||||||
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
|
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
|
||||||
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
|
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
|
||||||
|
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// u32
|
||||||
|
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
|
||||||
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
|
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
|
||||||
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
|
|
||||||
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
|
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
|
||||||
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
|
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
|
||||||
|
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// u8
|
||||||
|
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
|
||||||
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
|
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
|
||||||
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
|
|
||||||
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
|
|
||||||
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
|
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
|
||||||
|
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
|
||||||
|
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||||
|
#endif
|
@ -1252,3 +1252,119 @@ fn scatter_add() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||||
|
left: &[T],
|
||||||
|
right: &[T],
|
||||||
|
indices: &[I],
|
||||||
|
shape: &[usize],
|
||||||
|
dim: usize,
|
||||||
|
name: &'static str,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let input_buffer = new_buffer(&device, right);
|
||||||
|
let output = new_buffer(&device, left);
|
||||||
|
let indices_buffer = new_buffer(&device, indices);
|
||||||
|
call_index_add(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
shape,
|
||||||
|
shape,
|
||||||
|
shape,
|
||||||
|
dim,
|
||||||
|
&input_buffer,
|
||||||
|
0,
|
||||||
|
&indices_buffer,
|
||||||
|
0,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
read_to_vec(&output, left.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs the index-add kernels for every combination of index type
/// (u32, u8, i64) and float value type (f32, f16, bf16) on one small
/// fixture and checks each result against the same expected values.
#[test]
fn index_add() {
    // Base fixture in f32 / u32; every other variant is derived from it.
    let left = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let right = vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0];
    let indices = vec![0u32, 1, 0, 1, 0, 1];
    let shape = vec![6];

    // Index variants.
    let ids_u8: Vec<u8> = indices.iter().map(|&i| i as u8).collect();
    let ids_i64: Vec<i64> = indices.iter().map(|&i| i as i64).collect();

    // Value variants.
    let left_f16: Vec<f16> = left.iter().copied().map(f16::from_f32).collect();
    let right_f16: Vec<f16> = right.iter().copied().map(f16::from_f32).collect();
    let left_bf16: Vec<bf16> = left.iter().copied().map(bf16::from_f32).collect();
    let right_bf16: Vec<bf16> = right.iter().copied().map(bf16::from_f32).collect();

    // u32 indices
    let out = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f32");
    assert_eq!(out, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);

    let out = run_index_add(&left_f16, &right_f16, &indices, &shape, 0, "ia_u32_f16");
    assert_eq!(approx_f16(out, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);

    let out = run_index_add(&left_bf16, &right_bf16, &indices, &shape, 0, "ia_u32_bf16");
    assert_eq!(approx_bf16(out, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);

    // u8 indices
    let out = run_index_add(&left, &right, &ids_u8, &shape, 0, "ia_u8_f32");
    assert_eq!(out, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);

    let out = run_index_add(&left_f16, &right_f16, &ids_u8, &shape, 0, "ia_u8_f16");
    assert_eq!(approx_f16(out, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);

    let out = run_index_add(&left_bf16, &right_bf16, &ids_u8, &shape, 0, "ia_u8_bf16");
    assert_eq!(approx_bf16(out, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);

    // i64 indices
    let out = run_index_add(&left, &right, &ids_i64, &shape, 0, "ia_i64_f32");
    assert_eq!(out, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);

    let out = run_index_add(&left_f16, &right_f16, &ids_i64, &shape, 0, "ia_i64_f16");
    assert_eq!(approx_f16(out, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);

    let out = run_index_add(&left_bf16, &right_bf16, &ids_i64, &shape, 0, "ia_i64_bf16");
    assert_eq!(approx_bf16(out, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
}
|
||||||
|
Reference in New Issue
Block a user