mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
fmt
This commit is contained in:
@ -975,7 +975,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
@ -984,7 +983,6 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn random() {
|
fn random() {
|
||||||
|
|
||||||
fn calc_mean(data: &[f32]) -> f32 {
|
fn calc_mean(data: &[f32]) -> f32 {
|
||||||
let sum = data.iter().sum::<f32>() as f32;
|
let sum = data.iter().sum::<f32>() as f32;
|
||||||
let count = data.len();
|
let count = data.len();
|
||||||
@ -997,10 +995,14 @@ fn random() {
|
|||||||
let count = data.len();
|
let count = data.len();
|
||||||
assert!(count > 0);
|
assert!(count > 0);
|
||||||
|
|
||||||
let variance = data.iter().map(|value| {
|
let variance = data
|
||||||
|
.iter()
|
||||||
|
.map(|value| {
|
||||||
let diff = mean - (*value as f32);
|
let diff = mean - (*value as f32);
|
||||||
diff * diff
|
diff * diff
|
||||||
}).sum::<f32>() / count as f32;
|
})
|
||||||
|
.sum::<f32>()
|
||||||
|
/ count as f32;
|
||||||
|
|
||||||
variance.sqrt()
|
variance.sqrt()
|
||||||
}
|
}
|
||||||
@ -1017,11 +1019,29 @@ fn random() {
|
|||||||
|
|
||||||
macro_rules! validate_random {
|
macro_rules! validate_random {
|
||||||
($type:ty) => {
|
($type:ty) => {
|
||||||
let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
|
let results: Vec<f32> = run_random::<$type>(
|
||||||
|
concat!("rand_uniform_", stringify!($type)),
|
||||||
|
seed,
|
||||||
|
length,
|
||||||
|
min,
|
||||||
|
max,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(f32::from)
|
||||||
|
.collect();
|
||||||
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
|
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
|
||||||
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
|
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
|
||||||
|
|
||||||
let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
|
let results: Vec<f32> = run_random::<$type>(
|
||||||
|
concat!("rand_normal_", stringify!($type)),
|
||||||
|
seed,
|
||||||
|
length,
|
||||||
|
mean,
|
||||||
|
stddev,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(f32::from)
|
||||||
|
.collect();
|
||||||
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
|
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
|
||||||
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
|
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user