diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b15505f7..775ee0fa 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -975,7 +975,6 @@ fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: .unwrap(); } - command_buffer.commit(); command_buffer.wait_until_completed(); @@ -984,7 +983,6 @@ fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: #[test] fn random() { - fn calc_mean(data: &[f32]) -> f32 { let sum = data.iter().sum::() as f32; let count = data.len(); @@ -997,10 +995,14 @@ fn random() { let count = data.len(); assert!(count > 0); - let variance = data.iter().map(|value| { - let diff = mean - (*value as f32); - diff * diff - }).sum::() / count as f32; + let variance = data + .iter() + .map(|value| { + let diff = mean - (*value as f32); + diff * diff + }) + .sum::() + / count as f32; variance.sqrt() } @@ -1017,11 +1019,29 @@ fn random() { macro_rules! validate_random { ($type:ty) => { - let results: Vec = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect(); + let results: Vec = run_random::<$type>( + concat!("rand_uniform_", stringify!($type)), + seed, + length, + min, + max, + ) + .into_iter() + .map(f32::from) + .collect(); results.iter().for_each(|v| assert!(*v >= min && *v <= max)); assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); - let results: Vec = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect(); + let results: Vec = run_random::<$type>( + concat!("rand_normal_", stringify!($type)), + seed, + length, + mean, + stddev, + ) + .into_iter() + .map(f32::from) + .collect(); assert!((calc_mean(&results) - mean).abs() < mean / 10.0); assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0); }; @@ -1030,4 +1050,4 @@ fn random() { validate_random!(f32); validate_random!(f16); validate_random!(bf16); -} \ No newline at end of file +}