mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Gaussian normal distribution of PRNG via Box-Muller transform
This commit is contained in:
@ -2,10 +2,14 @@ use candle_core::{DType, Device, Tensor};
|
|||||||
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
|
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
fn run(a: &Tensor) {
|
fn rand_uniform(a: &Tensor) {
|
||||||
a.rand_like(0.0, 1.0).unwrap();
|
a.rand_like(0.0, 1.0).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn rand_normal(a: &Tensor) {
|
||||||
|
a.randn_like(100.0, 15.0).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
let b = 1;
|
let b = 1;
|
||||||
|
|
||||||
@ -13,18 +17,19 @@ fn criterion_benchmark(c: &mut Criterion) {
|
|||||||
let cols = 2048;
|
let cols = 2048;
|
||||||
|
|
||||||
let device = Device::new_metal(0).unwrap();
|
let device = Device::new_metal(0).unwrap();
|
||||||
|
let device2 = device.clone();
|
||||||
let dtype = DType::F32;
|
let dtype = DType::F32;
|
||||||
let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap();
|
let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap();
|
||||||
|
|
||||||
let flops = b * rows * cols;
|
let flops = b * rows * cols;
|
||||||
|
|
||||||
let mut group = c.benchmark_group("random_metal");
|
let mut group = c.benchmark_group("metal_random_uniform");
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
group.bench_function("iter", move |benches| {
|
group.bench_function("iter", move |benches| {
|
||||||
benches.iter_custom(|iters| {
|
benches.iter_custom(|iters| {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
for _i in 0..iters {
|
for _i in 0..iters {
|
||||||
run(black_box(&tensor));
|
rand_uniform(black_box(&tensor));
|
||||||
}
|
}
|
||||||
if let Device::Metal(device) = &device {
|
if let Device::Metal(device) = &device {
|
||||||
device.wait_until_completed().unwrap();
|
device.wait_until_completed().unwrap();
|
||||||
@ -35,6 +40,26 @@ fn criterion_benchmark(c: &mut Criterion) {
|
|||||||
})
|
})
|
||||||
});
|
});
|
||||||
group.finish();
|
group.finish();
|
||||||
|
|
||||||
|
let tensor = Tensor::zeros((b, rows, cols), dtype, &device2).unwrap();
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group("metal_random_normal");
|
||||||
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
|
group.bench_function("iter", move |benches| {
|
||||||
|
benches.iter_custom(|iters| {
|
||||||
|
let start = Instant::now();
|
||||||
|
for _i in 0..iters {
|
||||||
|
rand_normal(black_box(&tensor));
|
||||||
|
}
|
||||||
|
if let Device::Metal(device) = &device2 {
|
||||||
|
device.wait_until_completed().unwrap();
|
||||||
|
} else {
|
||||||
|
panic!("Expected metal device");
|
||||||
|
}
|
||||||
|
start.elapsed()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
group.finish();
|
||||||
}
|
}
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
criterion_group!(benches, criterion_benchmark);
|
||||||
|
@ -1385,7 +1385,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
compute_per_buffer,
|
compute_per_buffer,
|
||||||
buffers,
|
buffers,
|
||||||
kernels,
|
kernels,
|
||||||
seed
|
seed,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1467,8 +1467,9 @@ impl BackendDevice for MetalDevice {
|
|||||||
min as f32,
|
min as f32,
|
||||||
max as f32,
|
max as f32,
|
||||||
shape.elem_count(),
|
shape.elem_count(),
|
||||||
&buffer
|
&buffer,
|
||||||
).map_err(MetalError::from)?;
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
|
||||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||||
}
|
}
|
||||||
@ -1480,9 +1481,28 @@ impl BackendDevice for MetalDevice {
|
|||||||
mean: f64,
|
mean: f64,
|
||||||
stddev: f64,
|
stddev: f64,
|
||||||
) -> Result<Self::Storage> {
|
) -> Result<Self::Storage> {
|
||||||
// TODO is there a better way ?
|
let name = match dtype {
|
||||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
DType::F32 => "rand_normal_f32",
|
||||||
self.storage_from_cpu_storage(&cpu_storage)
|
DType::F16 => "rand_normal_f16",
|
||||||
|
DType::BF16 => "rand_normal_bf16",
|
||||||
|
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
||||||
|
};
|
||||||
|
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
|
||||||
|
let command_buffer = self.command_buffer()?;
|
||||||
|
candle_metal_kernels::call_random_normal(
|
||||||
|
&self.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.kernels,
|
||||||
|
name,
|
||||||
|
*self.seed.lock().unwrap(),
|
||||||
|
mean as f32,
|
||||||
|
stddev as f32,
|
||||||
|
shape.elem_count(),
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
|
||||||
|
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1415,7 +1415,6 @@ pub fn call_gemm(
|
|||||||
height: 1,
|
height: 1,
|
||||||
depth: 1,
|
depth: 1,
|
||||||
};
|
};
|
||||||
// println!("grid size {grid_size:?} group size {group_size:?}");
|
|
||||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
@ -1588,39 +1587,11 @@ pub fn call_random_uniform(
|
|||||||
"min must be less than max".to_string(),
|
"min must be less than max".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let size: usize = match name {
|
|
||||||
"rand_uniform_f32" => 4,
|
|
||||||
"rand_uniform_f16" | "rand_uniform_bf16" => 2,
|
|
||||||
_ => Err(MetalKernelError::LoadLibraryError(format!(
|
|
||||||
"{name} is not a valid kernel for random"
|
|
||||||
)))?,
|
|
||||||
};
|
|
||||||
|
|
||||||
let elems_per_key = length;
|
|
||||||
let bytes_per_key = size * elems_per_key;
|
|
||||||
|
|
||||||
let out_per_key = (bytes_per_key + 4 - 1) / 4;
|
|
||||||
let half_size = out_per_key / 2;
|
|
||||||
let odd = length % 2 != 0;
|
|
||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
let thread_group_count = MTLSize {
|
let odd = (length % 2 != 0) as usize;
|
||||||
width: length as u64,
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||||
height: half_size as u64 + odd as u64,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
let threads = std::cmp::min(
|
|
||||||
(half_size + odd as usize) as NSUInteger,
|
|
||||||
pipeline.max_total_threads_per_threadgroup(),
|
|
||||||
);
|
|
||||||
let thread_group_size = MTLSize {
|
|
||||||
width: threads,
|
|
||||||
height: 1,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -1635,5 +1606,36 @@ pub fn call_random_uniform(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_random_normal(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
kernels: &Kernels,
|
||||||
|
name: &'static str,
|
||||||
|
seed: u64,
|
||||||
|
mean: f32,
|
||||||
|
stddev: f32,
|
||||||
|
length: usize,
|
||||||
|
buffer: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
|
let odd = (length % 2 != 0) as usize;
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||||
|
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
set_params!(encoder, (length, seed, mean, stddev, buffer));
|
||||||
|
|
||||||
|
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
|
encoder.end_encoding();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
@ -33,29 +33,34 @@ static constexpr constant uint64_t PHI[16] = {
|
|||||||
// Combined Tausworthe and LCG Random Number Generator.
|
// Combined Tausworthe and LCG Random Number Generator.
|
||||||
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
|
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
|
||||||
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
|
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
|
||||||
class HybridTaus {
|
struct HybridTaus {
|
||||||
private:
|
|
||||||
thread float seed;
|
float state;
|
||||||
|
|
||||||
|
HybridTaus() thread = default;
|
||||||
|
HybridTaus() threadgroup = default;
|
||||||
|
HybridTaus() device = default;
|
||||||
|
HybridTaus() constant = default;
|
||||||
|
|
||||||
// Generate seeds for each thread.
|
// Generate seeds for each thread.
|
||||||
thread uint4 seed_per_thread(const ulong4 seeds) {
|
METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
|
||||||
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
|
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tausworthe generator.
|
// Tausworthe generator.
|
||||||
thread uint taus(const uint z, const int3 s, const uint M) {
|
METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
|
||||||
uint b = (((z << s.x) ^ z) >> s.y);
|
uint b = (((z << s.x) ^ z) >> s.y);
|
||||||
return (((z & M) << s.z) ^ b);
|
return (((z & M) << s.z) ^ b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// LCG generator.
|
// LCG generator.
|
||||||
thread uint lcg(const uint z) {
|
METAL_FUNC static uint lcg(const uint z) {
|
||||||
return (1664525 * z + 1013904223UL);
|
return (1664525 * z + 1013904223UL);
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
// Initialize the RNG state.
|
||||||
thread HybridTaus(const ulong4 seeds) {
|
METAL_FUNC static HybridTaus init(const ulong4 seeds) {
|
||||||
uint4 seed = this->seed_per_thread(seeds);
|
uint4 seed = seed_per_thread(seeds);
|
||||||
|
|
||||||
// Seed #1
|
// Seed #1
|
||||||
uint z1 = taus(seed.x, S1, 4294967294UL);
|
uint z1 = taus(seed.x, S1, 4294967294UL);
|
||||||
@ -84,52 +89,96 @@ public:
|
|||||||
z3 = taus(r1, S3, 429496280UL);
|
z3 = taus(r1, S3, 429496280UL);
|
||||||
z4 = lcg(r1);
|
z4 = lcg(r1);
|
||||||
|
|
||||||
this->seed = (z1^z2^z3^z4) * UNIF01_INV32;
|
HybridTaus rng;
|
||||||
|
rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
|
||||||
|
return rng;
|
||||||
}
|
}
|
||||||
|
|
||||||
thread float rand() {
|
METAL_FUNC float rand() {
|
||||||
uint seed = this->seed * UNIF01_NORM32;
|
uint seed = this->state * UNIF01_NORM32;
|
||||||
uint z1 = taus(seed, S1, 429496729UL);
|
uint z1 = taus(seed, S1, 429496729UL);
|
||||||
uint z2 = taus(seed, S2, 4294967288UL);
|
uint z2 = taus(seed, S2, 4294967288UL);
|
||||||
uint z3 = taus(seed, S3, 429496280UL);
|
uint z3 = taus(seed, S3, 429496280UL);
|
||||||
uint z4 = lcg(seed);
|
uint z4 = lcg(seed);
|
||||||
|
|
||||||
thread float old_seed = this->seed;
|
thread float result = this->state;
|
||||||
this->seed = (z1^z2^z3^z4) * UNIF01_INV32;
|
this->state = (z1^z2^z3^z4) * UNIF01_INV32;
|
||||||
return old_seed;
|
return result;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T> METAL_FUNC void rand_uniform(
|
template<typename T> METAL_FUNC void rand_uniform(
|
||||||
constant size_t &elem_count,
|
constant size_t &size,
|
||||||
constant ulong &seed,
|
constant ulong &seed,
|
||||||
constant float &min,
|
constant float &min,
|
||||||
constant float &max,
|
constant float &max,
|
||||||
device T *out,
|
device T *out,
|
||||||
uint tid [[thread_position_in_grid]]
|
uint tid [[thread_position_in_grid]]
|
||||||
) {
|
) {
|
||||||
if (tid >= elem_count) {
|
if (tid >= size) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
float diff = max - min;
|
float diff = max - min;
|
||||||
HybridTaus rng = HybridTaus({seed, tid, 1, 1});
|
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
|
||||||
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
||||||
|
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Gaussian normal distribution using Box-Muller transform:
|
||||||
|
// https://en.wikipedia.org/wiki/Box–Muller_transform
|
||||||
|
template<typename T> METAL_FUNC void normal(
|
||||||
|
constant size_t &size,
|
||||||
|
constant ulong &seed,
|
||||||
|
constant float &mean,
|
||||||
|
constant float &stddev,
|
||||||
|
device T *out,
|
||||||
|
uint tid [[thread_position_in_grid]]
|
||||||
|
) {
|
||||||
|
if (tid >= size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
|
||||||
|
float u1 = rng.rand();
|
||||||
|
float u2 = rng.rand();
|
||||||
|
|
||||||
|
float cosval;
|
||||||
|
float sinval = sincos(u1 * TWO_PI, cosval);
|
||||||
|
float mag = stddev * sqrt(-2.0 * log(u1));
|
||||||
|
float z0 = mag * cosval + mean;
|
||||||
|
float z1 = mag * sinval + mean;
|
||||||
|
|
||||||
|
out[tid] = static_cast<T>(z0);
|
||||||
|
out[size - tid] = static_cast<T>(z1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define UNIFORM_OP(NAME, T) \
|
#define UNIFORM_OP(NAME, T) \
|
||||||
kernel void rand_uniform_##NAME( \
|
kernel void rand_uniform_##NAME( \
|
||||||
constant size_t &elem_count, \
|
constant size_t &size, \
|
||||||
constant ulong &seed, \
|
constant ulong &seed, \
|
||||||
constant float &min, \
|
constant float &min, \
|
||||||
constant float &max, \
|
constant float &max, \
|
||||||
device T *out, \
|
device T *out, \
|
||||||
uint tid [[thread_position_in_grid]] \
|
uint tid [[thread_position_in_grid]] \
|
||||||
) { \
|
) { \
|
||||||
rand_uniform<T>(elem_count, seed, min, max, out, tid); \
|
rand_uniform<T>(size, seed, min, max, out, tid); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
#define NORMAL_OP(NAME, T) \
|
||||||
|
kernel void rand_normal_##NAME( \
|
||||||
|
constant size_t &size, \
|
||||||
|
constant ulong &seed, \
|
||||||
|
constant float &mean, \
|
||||||
|
constant float &stddev, \
|
||||||
|
device T *out, \
|
||||||
|
uint tid [[thread_position_in_grid]] \
|
||||||
|
) { \
|
||||||
|
normal<T>(size, seed, mean, stddev, out, tid); \
|
||||||
|
} \
|
||||||
|
|
||||||
|
|
||||||
#define RANDOM_OPS(NAME, T) \
|
#define RANDOM_OPS(NAME, T) \
|
||||||
UNIFORM_OP(NAME, T) \
|
UNIFORM_OP(NAME, T) \
|
||||||
|
NORMAL_OP(NAME, T) \
|
||||||
|
|
||||||
RANDOM_OPS(f32, float)
|
RANDOM_OPS(f32, float)
|
||||||
RANDOM_OPS(f16, half)
|
RANDOM_OPS(f16, half)
|
||||||
|
@ -806,28 +806,43 @@ fn gemm() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32, max: f32) -> Vec<T> {
|
fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||||
let device = device();
|
let device = device();
|
||||||
let fence = device.new_fence();
|
let fence = device.new_fence();
|
||||||
let kernels = Kernels::new(fence);
|
let kernels = Kernels::new(fence);
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let length = shape.iter().product::<usize>();
|
|
||||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||||
|
|
||||||
|
if name.starts_with("rand_uniform") {
|
||||||
call_random_uniform(
|
call_random_uniform(
|
||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
name,
|
name,
|
||||||
seed,
|
seed,
|
||||||
min,
|
a,
|
||||||
max,
|
b,
|
||||||
length,
|
length,
|
||||||
&output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
} else {
|
||||||
|
call_random_normal(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
name,
|
||||||
|
seed,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
length,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
@ -837,24 +852,50 @@ fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn random() {
|
fn random() {
|
||||||
use std::fs::File;
|
|
||||||
use std::io::prelude::*;
|
|
||||||
|
|
||||||
let shape = vec![1024, 4];
|
fn calc_mean(data: &[f32]) -> f32 {
|
||||||
let seed = 299792458;
|
let sum = data.iter().sum::<f32>() as f32;
|
||||||
let min = -30.0;
|
let count = data.len();
|
||||||
let max = 30.0;
|
assert!(count > 0);
|
||||||
let results = run_random::<f32>(seed, &shape, "rand_uniform_f32", min, max);
|
sum / count as f32
|
||||||
for &v in &results {
|
|
||||||
assert!(v >= min && v <= max);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Writing bytes to file for testing with ENT
|
fn calc_stddev(data: &[f32]) -> f32 {
|
||||||
// https://www.fourmilab.ch/random/
|
let mean = calc_mean(data);
|
||||||
// TODO: Remove before merge
|
let count = data.len();
|
||||||
let (head, body, tail) = unsafe { results.align_to::<u8>() };
|
assert!(count > 0);
|
||||||
assert!(head.is_empty());
|
|
||||||
assert!(tail.is_empty());
|
let variance = data.iter().map(|value| {
|
||||||
let mut file = File::create("test").unwrap();
|
let diff = mean - (*value as f32);
|
||||||
file.write_all(body).unwrap();
|
diff * diff
|
||||||
|
}).sum::<f32>() / count as f32;
|
||||||
|
|
||||||
|
variance.sqrt()
|
||||||
|
}
|
||||||
|
|
||||||
|
let shape = vec![1024, 10];
|
||||||
|
|
||||||
|
let length = shape.iter().product::<usize>();
|
||||||
|
let seed = 299792458;
|
||||||
|
|
||||||
|
let min = -30.0;
|
||||||
|
let max = 30.0;
|
||||||
|
let mean = 100.0;
|
||||||
|
let stddev = 50.0;
|
||||||
|
|
||||||
|
macro_rules! validate_random {
|
||||||
|
($type:ty) => {
|
||||||
|
let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
|
||||||
|
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
|
||||||
|
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
|
||||||
|
|
||||||
|
let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
|
||||||
|
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
|
||||||
|
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
validate_random!(f32);
|
||||||
|
validate_random!(f16);
|
||||||
|
validate_random!(bf16);
|
||||||
}
|
}
|
Reference in New Issue
Block a user