From 2ca086939f91f5d8ccec745e47648f74fa520988 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 30 Nov 2023 11:40:39 +0100 Subject: [PATCH] Put back affine strided tests --- candle-metal-kernels/src/tests.rs | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 66dc8d01..59f54fa9 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -295,7 +295,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { output.read_to_vec::(v.len()) } -fn _run_affine_strided( +fn run_affine_strided( v: &[T], shape: &[usize], strides: &[usize], @@ -314,7 +314,7 @@ fn _run_affine_strided( &device, command_buffer, &kernels, - "affine_float", + "affine_float_strided", shape, &input, strides, @@ -327,7 +327,8 @@ fn _run_affine_strided( command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) + let len: usize = shape.iter().product(); + output.read_to_vec::(len) } #[test] @@ -345,15 +346,17 @@ fn affine() { assert_eq!(result, vec![2.6; 40_000]); } -// #[test] -// fn affine_strided() { -// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; -// let mul = 1.5; -// let add = 1.1; -// let result = run_affine_(&input, mul, add); -// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); - -// } +#[test] +fn affine_strided() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let shape = [4]; + let strides = [2]; + let result = run_affine_strided(&input, &shape, &strides, mul, add); + // 1 on 2 + assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]); +} #[test] fn index_select() {