mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
fmt
This commit is contained in:
@ -975,7 +975,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
@ -984,7 +983,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
|
||||
|
||||
#[test]
|
||||
fn random() {
|
||||
|
||||
fn calc_mean(data: &[f32]) -> f32 {
|
||||
let sum = data.iter().sum::<f32>() as f32;
|
||||
let count = data.len();
|
||||
@ -997,10 +995,14 @@ fn random() {
|
||||
let count = data.len();
|
||||
assert!(count > 0);
|
||||
|
||||
let variance = data.iter().map(|value| {
|
||||
let diff = mean - (*value as f32);
|
||||
diff * diff
|
||||
}).sum::<f32>() / count as f32;
|
||||
let variance = data
|
||||
.iter()
|
||||
.map(|value| {
|
||||
let diff = mean - (*value as f32);
|
||||
diff * diff
|
||||
})
|
||||
.sum::<f32>()
|
||||
/ count as f32;
|
||||
|
||||
variance.sqrt()
|
||||
}
|
||||
@ -1017,11 +1019,29 @@ fn random() {
|
||||
|
||||
macro_rules! validate_random {
|
||||
($type:ty) => {
|
||||
let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
concat!("rand_uniform_", stringify!($type)),
|
||||
seed,
|
||||
length,
|
||||
min,
|
||||
max,
|
||||
)
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
|
||||
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
|
||||
|
||||
let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
concat!("rand_normal_", stringify!($type)),
|
||||
seed,
|
||||
length,
|
||||
mean,
|
||||
stddev,
|
||||
)
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
|
||||
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
|
||||
};
|
||||
@ -1030,4 +1050,4 @@ fn random() {
|
||||
validate_random!(f32);
|
||||
validate_random!(f16);
|
||||
validate_random!(bf16);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user