Gaussian normal distribution of PRNG via Box-Muller transform

This commit is contained in:
Ivar Flakstad
2024-01-05 21:18:12 +01:00
parent 955e63c803
commit 6bf52b9fdf
5 changed files with 238 additions and 101 deletions

View File

@ -2,10 +2,14 @@ use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
use std::time::Instant; use std::time::Instant;
fn run(a: &Tensor) { fn rand_uniform(a: &Tensor) {
a.rand_like(0.0, 1.0).unwrap(); a.rand_like(0.0, 1.0).unwrap();
} }
fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn criterion_benchmark(c: &mut Criterion) { fn criterion_benchmark(c: &mut Criterion) {
let b = 1; let b = 1;
@ -13,18 +17,19 @@ fn criterion_benchmark(c: &mut Criterion) {
let cols = 2048; let cols = 2048;
let device = Device::new_metal(0).unwrap(); let device = Device::new_metal(0).unwrap();
let device2 = device.clone();
let dtype = DType::F32; let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap(); let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap();
let flops = b * rows * cols; let flops = b * rows * cols;
let mut group = c.benchmark_group("random_metal"); let mut group = c.benchmark_group("metal_random_uniform");
group.throughput(Throughput::Bytes(flops as u64)); group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| { group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| { benches.iter_custom(|iters| {
let start = Instant::now(); let start = Instant::now();
for _i in 0..iters { for _i in 0..iters {
run(black_box(&tensor)); rand_uniform(black_box(&tensor));
} }
if let Device::Metal(device) = &device { if let Device::Metal(device) = &device {
device.wait_until_completed().unwrap(); device.wait_until_completed().unwrap();
@ -35,6 +40,26 @@ fn criterion_benchmark(c: &mut Criterion) {
}) })
}); });
group.finish(); group.finish();
let tensor = Tensor::zeros((b, rows, cols), dtype, &device2).unwrap();
let mut group = c.benchmark_group("metal_random_normal");
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
if let Device::Metal(device) = &device2 {
device.wait_until_completed().unwrap();
} else {
panic!("Expected metal device");
}
start.elapsed()
})
});
group.finish();
} }
criterion_group!(benches, criterion_benchmark); criterion_group!(benches, criterion_benchmark);

View File

@ -1385,7 +1385,7 @@ impl BackendDevice for MetalDevice {
compute_per_buffer, compute_per_buffer,
buffers, buffers,
kernels, kernels,
seed seed,
}) })
} }
@ -1467,8 +1467,9 @@ impl BackendDevice for MetalDevice {
min as f32, min as f32,
max as f32, max as f32,
shape.elem_count(), shape.elem_count(),
&buffer &buffer,
).map_err(MetalError::from)?; )
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype)) Ok(Self::Storage::new(buffer, self.clone(), dtype))
} }
@ -1480,9 +1481,28 @@ impl BackendDevice for MetalDevice {
mean: f64, mean: f64,
stddev: f64, stddev: f64,
) -> Result<Self::Storage> { ) -> Result<Self::Storage> {
// TODO is there a better way ? let name = match dtype {
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?; DType::F32 => "rand_normal_f32",
self.storage_from_cpu_storage(&cpu_storage) DType::F16 => "rand_normal_f16",
DType::BF16 => "rand_normal_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_normal(
&self.device,
&command_buffer,
&self.kernels,
name,
*self.seed.lock().unwrap(),
mean as f32,
stddev as f32,
shape.elem_count(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
} }
} }

View File

@ -1415,7 +1415,6 @@ pub fn call_gemm(
height: 1, height: 1,
depth: 1, depth: 1,
}; };
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
@ -1588,39 +1587,11 @@ pub fn call_random_uniform(
"min must be less than max".to_string(), "min must be less than max".to_string(),
)); ));
} }
let size: usize = match name {
"rand_uniform_f32" => 4,
"rand_uniform_f16" | "rand_uniform_bf16" => 2,
_ => Err(MetalKernelError::LoadLibraryError(format!(
"{name} is not a valid kernel for random"
)))?,
};
let elems_per_key = length;
let bytes_per_key = size * elems_per_key;
let out_per_key = (bytes_per_key + 4 - 1) / 4;
let half_size = out_per_key / 2;
let odd = length % 2 != 0;
let pipeline = kernels.load_pipeline(device, Source::Random, name)?; let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let thread_group_count = MTLSize { let odd = (length % 2 != 0) as usize;
width: length as u64, let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
height: half_size as u64 + odd as u64,
depth: 1,
};
let threads = std::cmp::min(
(half_size + odd as usize) as NSUInteger,
pipeline.max_total_threads_per_threadgroup(),
);
let thread_group_size = MTLSize {
width: threads,
height: 1,
depth: 1,
};
encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -1635,5 +1606,36 @@ pub fn call_random_uniform(
Ok(()) Ok(())
} }
#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
seed: u64,
mean: f32,
stddev: f32,
length: usize,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, seed, mean, stddev, buffer));
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;

View File

@ -33,29 +33,34 @@ static constexpr constant uint64_t PHI[16] = {
// Combined Tausworthe and LCG Random Number Generator. // Combined Tausworthe and LCG Random Number Generator.
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf // https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
class HybridTaus { struct HybridTaus {
private:
thread float seed; float state;
HybridTaus() thread = default;
HybridTaus() threadgroup = default;
HybridTaus() device = default;
HybridTaus() constant = default;
// Generate seeds for each thread. // Generate seeds for each thread.
thread uint4 seed_per_thread(const ulong4 seeds) { METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL)); return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
} }
// Tausworthe generator. // Tausworthe generator.
thread uint taus(const uint z, const int3 s, const uint M) { METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
uint b = (((z << s.x) ^ z) >> s.y); uint b = (((z << s.x) ^ z) >> s.y);
return (((z & M) << s.z) ^ b); return (((z & M) << s.z) ^ b);
} }
// LCG generator. // LCG generator.
thread uint lcg(const uint z) { METAL_FUNC static uint lcg(const uint z) {
return (1664525 * z + 1013904223UL); return (1664525 * z + 1013904223UL);
} }
public: // Initialize the RNG state.
thread HybridTaus(const ulong4 seeds) { METAL_FUNC static HybridTaus init(const ulong4 seeds) {
uint4 seed = this->seed_per_thread(seeds); uint4 seed = seed_per_thread(seeds);
// Seed #1 // Seed #1
uint z1 = taus(seed.x, S1, 4294967294UL); uint z1 = taus(seed.x, S1, 4294967294UL);
@ -84,52 +89,96 @@ public:
z3 = taus(r1, S3, 429496280UL); z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1); z4 = lcg(r1);
this->seed = (z1^z2^z3^z4) * UNIF01_INV32; HybridTaus rng;
rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
return rng;
} }
thread float rand() { METAL_FUNC float rand() {
uint seed = this->seed * UNIF01_NORM32; uint seed = this->state * UNIF01_NORM32;
uint z1 = taus(seed, S1, 429496729UL); uint z1 = taus(seed, S1, 429496729UL);
uint z2 = taus(seed, S2, 4294967288UL); uint z2 = taus(seed, S2, 4294967288UL);
uint z3 = taus(seed, S3, 429496280UL); uint z3 = taus(seed, S3, 429496280UL);
uint z4 = lcg(seed); uint z4 = lcg(seed);
thread float old_seed = this->seed; thread float result = this->state;
this->seed = (z1^z2^z3^z4) * UNIF01_INV32; this->state = (z1^z2^z3^z4) * UNIF01_INV32;
return old_seed; return result;
} }
}; };
template<typename T> METAL_FUNC void rand_uniform( template<typename T> METAL_FUNC void rand_uniform(
constant size_t &elem_count, constant size_t &size,
constant ulong &seed, constant ulong &seed,
constant float &min, constant float &min,
constant float &max, constant float &max,
device T *out, device T *out,
uint tid [[thread_position_in_grid]] uint tid [[thread_position_in_grid]]
) { ) {
if (tid >= elem_count) { if (tid >= size) {
return; return;
} }
float diff = max - min; float diff = max - min;
HybridTaus rng = HybridTaus({seed, tid, 1, 1}); HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min); out[tid] = static_cast<T>(rng.rand() * diff + min);
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
}
// Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/BoxMuller_transform
template<typename T> METAL_FUNC void normal(
constant size_t &size,
constant ulong &seed,
constant float &mean,
constant float &stddev,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
float u1 = rng.rand();
float u2 = rng.rand();
float cosval;
float sinval = sincos(u1 * TWO_PI, cosval);
float mag = stddev * sqrt(-2.0 * log(u1));
float z0 = mag * cosval + mean;
float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0);
out[size - tid] = static_cast<T>(z1);
} }
#define UNIFORM_OP(NAME, T) \ #define UNIFORM_OP(NAME, T) \
kernel void rand_uniform_##NAME( \ kernel void rand_uniform_##NAME( \
constant size_t &elem_count, \ constant size_t &size, \
constant ulong &seed, \ constant ulong &seed, \
constant float &min, \ constant float &min, \
constant float &max, \ constant float &max, \
device T *out, \ device T *out, \
uint tid [[thread_position_in_grid]] \ uint tid [[thread_position_in_grid]] \
) { \ ) { \
rand_uniform<T>(elem_count, seed, min, max, out, tid); \ rand_uniform<T>(size, seed, min, max, out, tid); \
} \ } \
#define NORMAL_OP(NAME, T) \
kernel void rand_normal_##NAME( \
constant size_t &size, \
constant ulong &seed, \
constant float &mean, \
constant float &stddev, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
normal<T>(size, seed, mean, stddev, out, tid); \
} \
#define RANDOM_OPS(NAME, T) \ #define RANDOM_OPS(NAME, T) \
UNIFORM_OP(NAME, T) \ UNIFORM_OP(NAME, T) \
NORMAL_OP(NAME, T) \
RANDOM_OPS(f32, float) RANDOM_OPS(f32, float)
RANDOM_OPS(f16, half) RANDOM_OPS(f16, half)

View File

@ -806,28 +806,43 @@ fn gemm() {
); );
} }
fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32, max: f32) -> Vec<T> { fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device(); let device = device();
let fence = device.new_fence(); let fence = device.new_fence();
let kernels = Kernels::new(fence); let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let length = shape.iter().product::<usize>();
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
if name.starts_with("rand_uniform") {
call_random_uniform( call_random_uniform(
&device, &device,
command_buffer, command_buffer,
&kernels, &kernels,
name, name,
seed, seed,
min, a,
max, b,
length, length,
&output, &output,
) )
.unwrap(); .unwrap();
} else {
call_random_normal(
&device,
command_buffer,
&kernels,
name,
seed,
a,
b,
length,
&output,
)
.unwrap();
}
command_buffer.commit(); command_buffer.commit();
command_buffer.wait_until_completed(); command_buffer.wait_until_completed();
@ -837,24 +852,50 @@ fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32
#[test] #[test]
fn random() { fn random() {
use std::fs::File;
use std::io::prelude::*;
let shape = vec![1024, 4]; fn calc_mean(data: &[f32]) -> f32 {
let seed = 299792458; let sum = data.iter().sum::<f32>() as f32;
let min = -30.0; let count = data.len();
let max = 30.0; assert!(count > 0);
let results = run_random::<f32>(seed, &shape, "rand_uniform_f32", min, max); sum / count as f32
for &v in &results {
assert!(v >= min && v <= max);
} }
// Writing bytes to file for testing with ENT fn calc_stddev(data: &[f32]) -> f32 {
// https://www.fourmilab.ch/random/ let mean = calc_mean(data);
// TODO: Remove before merge let count = data.len();
let (head, body, tail) = unsafe { results.align_to::<u8>() }; assert!(count > 0);
assert!(head.is_empty());
assert!(tail.is_empty()); let variance = data.iter().map(|value| {
let mut file = File::create("test").unwrap(); let diff = mean - (*value as f32);
file.write_all(body).unwrap(); diff * diff
}).sum::<f32>() / count as f32;
variance.sqrt()
}
let shape = vec![1024, 10];
let length = shape.iter().product::<usize>();
let seed = 299792458;
let min = -30.0;
let max = 30.0;
let mean = 100.0;
let stddev = 50.0;
macro_rules! validate_random {
($type:ty) => {
let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
};
}
validate_random!(f32);
validate_random!(f16);
validate_random!(bf16);
} }