This commit is contained in:
Ivar Flakstad
2024-01-12 07:26:42 +01:00
parent e63bb8661b
commit e06e8d0dbe

View File

@ -975,7 +975,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
@ -984,7 +983,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
#[test]
fn random() {
fn calc_mean(data: &[f32]) -> f32 {
let sum = data.iter().sum::<f32>() as f32;
let count = data.len();
@ -997,10 +995,14 @@ fn random() {
let count = data.len();
assert!(count > 0);
let variance = data.iter().map(|value| {
let diff = mean - (*value as f32);
diff * diff
}).sum::<f32>() / count as f32;
let variance = data
.iter()
.map(|value| {
let diff = mean - (*value as f32);
diff * diff
})
.sum::<f32>()
/ count as f32;
variance.sqrt()
}
@ -1017,11 +1019,29 @@ fn random() {
macro_rules! validate_random {
($type:ty) => {
let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
let results: Vec<f32> = run_random::<$type>(
concat!("rand_uniform_", stringify!($type)),
seed,
length,
min,
max,
)
.into_iter()
.map(f32::from)
.collect();
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
let results: Vec<f32> = run_random::<$type>(
concat!("rand_normal_", stringify!($type)),
seed,
length,
mean,
stddev,
)
.into_iter()
.map(f32::from)
.collect();
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
};
@ -1030,4 +1050,4 @@ fn random() {
validate_random!(f32);
validate_random!(f16);
validate_random!(bf16);
}
}