mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
add test for index add and add missing match statements (#1862)
This commit is contained in:
@ -1252,3 +1252,119 @@ fn scatter_add() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
left: &[T],
|
||||
right: &[T],
|
||||
indices: &[I],
|
||||
shape: &[usize],
|
||||
dim: usize,
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input_buffer = new_buffer(&device, right);
|
||||
let output = new_buffer(&device, left);
|
||||
let indices_buffer = new_buffer(&device, indices);
|
||||
call_index_add(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
shape,
|
||||
shape,
|
||||
shape,
|
||||
dim,
|
||||
&input_buffer,
|
||||
0,
|
||||
&indices_buffer,
|
||||
0,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
read_to_vec(&output, left.len())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_add() {
|
||||
let left = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let right = vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0];
|
||||
let indices = vec![0u32, 1, 0, 1, 0, 1];
|
||||
let shape = vec![6];
|
||||
|
||||
// u32, f32
|
||||
{
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f32");
|
||||
assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
// u32, f16
|
||||
{
|
||||
let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f16");
|
||||
assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
// u32, bf16
|
||||
{
|
||||
let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_bf16");
|
||||
assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
// u8, f32
|
||||
{
|
||||
let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f32");
|
||||
assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
// u8, f16
|
||||
{
|
||||
let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
|
||||
let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f16");
|
||||
assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
// u8, bf16
|
||||
{
|
||||
let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
|
||||
let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_bf16");
|
||||
assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
// i64, f32
|
||||
{
|
||||
let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f32");
|
||||
assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
// i64, f16
|
||||
{
|
||||
let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||
let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f16");
|
||||
assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
// i64, bf16
|
||||
{
|
||||
let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||
let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
|
||||
let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_bf16");
|
||||
assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user