Put back affine strided tests

This commit is contained in:
Nicolas Patry
2023-11-30 11:40:39 +01:00
parent 4349ff1fc2
commit 2ca086939f

View File

@ -295,7 +295,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
output.read_to_vec::<T>(v.len())
}
fn _run_affine_strided<T: Clone>(
fn run_affine_strided<T: Clone>(
v: &[T],
shape: &[usize],
strides: &[usize],
@ -314,7 +314,7 @@ fn _run_affine_strided<T: Clone>(
&device,
command_buffer,
&kernels,
"affine_float",
"affine_float_strided",
shape,
&input,
strides,
@ -327,7 +327,8 @@ fn _run_affine_strided<T: Clone>(
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
let len: usize = shape.iter().product();
output.read_to_vec::<T>(len)
}
#[test]
@ -345,15 +346,17 @@ fn affine() {
assert_eq!(result, vec![2.6; 40_000]);
}
// #[test]
// fn affine_strided() {
// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
// let mul = 1.5;
// let add = 1.1;
// let result = run_affine_(&input, mul, add);
// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
// }
#[test]
fn affine_strided() {
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mul = 1.5;
let add = 1.1;
let shape = [4];
let strides = [2];
let result = run_affine_strided(&input, &shape, &strides, mul, add);
// 1 on 2
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}
#[test]
fn index_select() {