Put back affine strided tests

This commit is contained in:
Nicolas Patry
2023-11-30 11:40:39 +01:00
parent 4349ff1fc2
commit 2ca086939f

View File

@ -295,7 +295,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
output.read_to_vec::<T>(v.len()) output.read_to_vec::<T>(v.len())
} }
fn _run_affine_strided<T: Clone>( fn run_affine_strided<T: Clone>(
v: &[T], v: &[T],
shape: &[usize], shape: &[usize],
strides: &[usize], strides: &[usize],
@ -314,7 +314,7 @@ fn _run_affine_strided<T: Clone>(
&device, &device,
command_buffer, command_buffer,
&kernels, &kernels,
"affine_float", "affine_float_strided",
shape, shape,
&input, &input,
strides, strides,
@ -327,7 +327,8 @@ fn _run_affine_strided<T: Clone>(
command_buffer.commit(); command_buffer.commit();
command_buffer.wait_until_completed(); command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len()) let len: usize = shape.iter().product();
output.read_to_vec::<T>(len)
} }
#[test] #[test]
@ -345,15 +346,17 @@ fn affine() {
assert_eq!(result, vec![2.6; 40_000]); assert_eq!(result, vec![2.6; 40_000]);
} }
// #[test] #[test]
// fn affine_strided() { fn affine_strided() {
// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
// let mul = 1.5; let mul = 1.5;
// let add = 1.1; let add = 1.1;
// let result = run_affine_(&input, mul, add); let shape = [4];
// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); let strides = [2];
let result = run_affine_strided(&input, &shape, &strides, mul, add);
// } // 1 on 2
assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
}
#[test] #[test]
fn index_select() { fn index_select() {