From 6bf52b9fdf82ad775611e82924d73172660a605e Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Fri, 5 Jan 2024 21:18:12 +0100
Subject: [PATCH] Gaussian normal distribution of PRNG via Box-Muller transform

Add the `rand_normal_{f32,f16,bf16}` Metal kernels together with
`call_random_normal`, and use them in `MetalDevice::rand_normal` instead of
sampling on the CPU and copying the storage over.

Both random kernels write two outputs per thread (`out[tid]` and the mirrored
slot `out[size - 1 - tid]`); the launch is sized with
`linear_split(&pipeline, length / 2 + odd)`.

`HybridTaus` becomes a plain struct that is built through a static `init`
function, and the kernel test now checks mean and standard deviation for
f32, f16 and bf16.

---
Review notes (changes relative to the patch as first posted):
 * random.metal, normal<T>: the Box-Muller angle is taken from `u2`; it used
   `u1`, which is also the radius input, and left `u2` unused.
 * random.metal, rand_uniform<T> / normal<T>: the second store goes to
   `out[size - 1 - tid]`; `out[size - tid]` is one past the end for tid == 0.
 * metal_backend.rs, rand_normal: the unsupported-dtype message now names
   `rand_normal` rather than `rand_uniform`.

 candle-core/benches/random.rs         |  31 +++++++-
 candle-core/src/metal_backend.rs      |  32 ++++++--
 candle-metal-kernels/src/lib.rs       |  64 +++++++--------
 candle-metal-kernels/src/random.metal | 107 +++++++++++++++++++-------
 candle-metal-kernels/src/tests.rs     | 105 +++++++++++++++++--------
 5 files changed, 238 insertions(+), 101 deletions(-)

diff --git a/candle-core/benches/random.rs b/candle-core/benches/random.rs
index ce42fd4e..781d8b39 100644
--- a/candle-core/benches/random.rs
+++ b/candle-core/benches/random.rs
@@ -2,10 +2,14 @@ use candle_core::{DType, Device, Tensor};
 use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
 use std::time::Instant;
 
-fn run(a: &Tensor) {
+fn rand_uniform(a: &Tensor) {
     a.rand_like(0.0, 1.0).unwrap();
 }
 
+fn rand_normal(a: &Tensor) {
+    a.randn_like(100.0, 15.0).unwrap();
+}
+
 fn criterion_benchmark(c: &mut Criterion) {
     let b = 1;
 
@@ -13,18 +17,19 @@ fn criterion_benchmark(c: &mut Criterion) {
     let cols = 2048;
 
     let device = Device::new_metal(0).unwrap();
+    let device2 = device.clone();
     let dtype = DType::F32;
     let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap();
 
     let flops = b * rows * cols;
 
-    let mut group = c.benchmark_group("random_metal");
+    let mut group = c.benchmark_group("metal_random_uniform");
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |benches| {
         benches.iter_custom(|iters| {
             let start = Instant::now();
             for _i in 0..iters {
-                run(black_box(&tensor));
+                rand_uniform(black_box(&tensor));
             }
             if let Device::Metal(device) = &device {
                 device.wait_until_completed().unwrap();
@@ -35,6 +40,26 @@ fn criterion_benchmark(c: &mut Criterion) {
         })
     });
     group.finish();
+
+    let tensor = Tensor::zeros((b, rows, cols), dtype, &device2).unwrap();
+
+    let mut group = c.benchmark_group("metal_random_normal");
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |benches| {
+        benches.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                rand_normal(black_box(&tensor));
+            }
+            if let Device::Metal(device) = &device2 {
+                device.wait_until_completed().unwrap();
+            } else {
+                panic!("Expected metal device");
+            }
+            start.elapsed()
+        })
+    });
+    group.finish();
 }
 
 criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 059cf24b..73a532e6 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1385,7 +1385,7 @@ impl BackendDevice for MetalDevice {
             compute_per_buffer,
             buffers,
             kernels,
-            seed
+            seed,
         })
     }
 
@@ -1467,8 +1467,9 @@ impl BackendDevice for MetalDevice {
             min as f32,
             max as f32,
             shape.elem_count(),
-            &buffer
-        ).map_err(MetalError::from)?;
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
 
         Ok(Self::Storage::new(buffer, self.clone(), dtype))
     }
@@ -1480,9 +1481,28 @@ impl BackendDevice for MetalDevice {
         mean: f64,
         stddev: f64,
     ) -> Result<Self::Storage> {
-        // TODO is there a better way ?
-        let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
-        self.storage_from_cpu_storage(&cpu_storage)
+        let name = match dtype {
+            DType::F32 => "rand_normal_f32",
+            DType::F16 => "rand_normal_f16",
+            DType::BF16 => "rand_normal_bf16",
+            dtype => crate::bail!("rand_normal not implemented for {dtype:?}"),
+        };
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
+        let command_buffer = self.command_buffer()?;
+        candle_metal_kernels::call_random_normal(
+            &self.device,
+            &command_buffer,
+            &self.kernels,
+            name,
+            *self.seed.lock().unwrap(),
+            mean as f32,
+            stddev as f32,
+            shape.elem_count(),
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
+
+        Ok(Self::Storage::new(buffer, self.clone(), dtype))
     }
 }
 
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 04442c8a..e2603b3b 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1415,7 +1415,6 @@ pub fn call_gemm(
         height: 1,
         depth: 1,
     };
-    // println!("grid size {grid_size:?} group size {group_size:?}");
     encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
@@ -1588,39 +1587,11 @@ pub fn call_random_uniform(
             "min must be less than max".to_string(),
         ));
     }
-
-    let size: usize = match name {
-        "rand_uniform_f32" => 4,
-        "rand_uniform_f16" | "rand_uniform_bf16" => 2,
-        _ => Err(MetalKernelError::LoadLibraryError(format!(
-            "{name} is not a valid kernel for random"
-        )))?,
-    };
-
-    let elems_per_key = length;
-    let bytes_per_key = size * elems_per_key;
-
-    let out_per_key = (bytes_per_key + 4 - 1) / 4;
-    let half_size = out_per_key / 2;
-    let odd = length % 2 != 0;
-
     let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
     let encoder = command_buffer.new_compute_command_encoder();
 
-    let thread_group_count = MTLSize {
-        width: length as u64,
-        height: half_size as u64 + odd as u64,
-        depth: 1,
-    };
-    let threads = std::cmp::min(
-        (half_size + odd as usize) as NSUInteger,
-        pipeline.max_total_threads_per_threadgroup(),
-    );
-    let thread_group_size = MTLSize {
-        width: threads,
-        height: 1,
-        depth: 1,
-    };
+    let odd = (length % 2 != 0) as usize;
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
 
     encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
@@ -1635,5 +1606,36 @@ pub fn call_random_uniform(
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_random_normal(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    seed: u64,
+    mean: f32,
+    stddev: f32,
+    length: usize,
+    buffer: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    let odd = (length % 2 != 0) as usize;
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
+
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(encoder, (length, seed, mean, stddev, buffer));
+
+    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests;
diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal
index 1604123d..5369e8e2 100644
--- a/candle-metal-kernels/src/random.metal
+++ b/candle-metal-kernels/src/random.metal
@@ -33,29 +33,34 @@ static constexpr constant uint64_t PHI[16] = {
 // Combined Tausworthe and LCG Random Number Generator.
 // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
 // https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
-class HybridTaus {
-private:
-    thread float seed;
+struct HybridTaus {
+
+    float state;
+
+    HybridTaus() thread = default;
+    HybridTaus() threadgroup = default;
+    HybridTaus() device = default;
+    HybridTaus() constant = default;
 
     // Generate seeds for each thread.
-    thread uint4 seed_per_thread(const ulong4 seeds) {
+    METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
         return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
     }
 
     // Tausworthe generator.
-    thread uint taus(const uint z, const int3 s, const uint M) {
+    METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
         uint b = (((z << s.x) ^ z) >> s.y);
         return (((z & M) << s.z) ^ b);
     }
 
     // LCG generator.
-    thread uint lcg(const uint z) {
+    METAL_FUNC static uint lcg(const uint z) {
         return (1664525 * z + 1013904223UL);
     }
 
-public:
-    thread HybridTaus(const ulong4 seeds) {
-        uint4 seed = this->seed_per_thread(seeds);
+    // Initialize the RNG state.
+    METAL_FUNC static HybridTaus init(const ulong4 seeds) {
+        uint4 seed = seed_per_thread(seeds);
 
         // Seed #1
         uint z1 = taus(seed.x, S1, 4294967294UL);
@@ -84,52 +89,96 @@ public:
         z3 = taus(r1, S3, 429496280UL);
         z4 = lcg(r1);
 
-        this->seed = (z1^z2^z3^z4) * UNIF01_INV32;
+        HybridTaus rng;
+        rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
+        return rng;
     }
 
-    thread float rand() {
-        uint seed = this->seed * UNIF01_NORM32;
+    METAL_FUNC float rand() {
+        uint seed = this->state * UNIF01_NORM32;
         uint z1 = taus(seed, S1, 429496729UL);
         uint z2 = taus(seed, S2, 4294967288UL);
         uint z3 = taus(seed, S3, 429496280UL);
         uint z4 = lcg(seed);
 
-        thread float old_seed = this->seed;
-        this->seed = (z1^z2^z3^z4) * UNIF01_INV32;
-        return old_seed;
+        thread float result = this->state;
+        this->state = (z1^z2^z3^z4) * UNIF01_INV32;
+        return result;
     }
 };
 
 template<typename T> METAL_FUNC void rand_uniform(
-    constant size_t &elem_count,
+    constant size_t &size,
     constant ulong &seed,
     constant float &min,
     constant float &max,
     device T *out,
     uint tid [[thread_position_in_grid]]
 ) {
-    if (tid >= elem_count) {
+    if (tid >= size) {
         return;
     }
     float diff = max - min;
-    HybridTaus rng = HybridTaus({seed, tid, 1, 1});
+    HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
     out[tid] = static_cast<T>(rng.rand() * diff + min);
+    out[size - 1 - tid] = static_cast<T>(rng.rand() * diff + min); // mirrored slot; `size - tid` is out[size] for tid == 0
 }
 
-#define UNIFORM_OP(NAME, T) \
-kernel void rand_uniform_##NAME( \
-    constant size_t &elem_count, \
-    constant ulong &seed, \
-    constant float &min, \
-    constant float &max, \
-    device T *out, \
-    uint tid [[thread_position_in_grid]] \
-) { \
-    rand_uniform<T>(elem_count, seed, min, max, out, tid); \
-} \
+// Create Gaussian normal distribution using Box-Muller transform:
+// https://en.wikipedia.org/wiki/Box–Muller_transform
+template<typename T> METAL_FUNC void normal(
+    constant size_t &size,
+    constant ulong &seed,
+    constant float &mean,
+    constant float &stddev,
+    device T *out,
+    uint tid [[thread_position_in_grid]]
+) {
+    if (tid >= size) {
+        return;
+    }
+    HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
+    float u1 = rng.rand();
+    float u2 = rng.rand();
+
+    float cosval;
+    float sinval = sincos(u2 * TWO_PI, cosval); // angle from the second draw; u1 only feeds the radius
+    float mag = stddev * sqrt(-2.0 * log(u1));
+    float z0 = mag * cosval + mean;
+    float z1 = mag * sinval + mean;
+
+    out[tid] = static_cast<T>(z0);
+    out[size - 1 - tid] = static_cast<T>(z1); // mirrored slot; `size - tid` is out[size] for tid == 0
+}
+
+#define UNIFORM_OP(NAME, T) \
+kernel void rand_uniform_##NAME( \
+    constant size_t &size, \
+    constant ulong &seed, \
+    constant float &min, \
+    constant float &max, \
+    device T *out, \
+    uint tid [[thread_position_in_grid]] \
+) { \
+    rand_uniform<T>(size, seed, min, max, out, tid); \
+} \
+
+#define NORMAL_OP(NAME, T) \
+kernel void rand_normal_##NAME( \
+    constant size_t &size, \
+    constant ulong &seed, \
+    constant float &mean, \
+    constant float &stddev, \
+    device T *out, \
+    uint tid [[thread_position_in_grid]] \
+) { \
+    normal<T>(size, seed, mean, stddev, out, tid); \
+} \
+
 
 #define RANDOM_OPS(NAME, T) \
 UNIFORM_OP(NAME, T) \
+NORMAL_OP(NAME, T) \
 
 RANDOM_OPS(f32, float)
 RANDOM_OPS(f16, half)
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index d0ca8330..067dece8 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -806,28 +806,43 @@ fn gemm() {
     );
 }
 
-fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32, max: f32) -> Vec<T> {
+fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
     let device = device();
     let fence = device.new_fence();
     let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
-    let length = shape.iter().product::<usize>();
     let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
 
-    call_random_uniform(
-        &device,
-        command_buffer,
-        &kernels,
-        name,
-        seed,
-        min,
-        max,
-        length,
-        &output,
-    )
-    .unwrap();
+    if name.starts_with("rand_uniform") {
+        call_random_uniform(
+            &device,
+            command_buffer,
+            &kernels,
+            name,
+            seed,
+            a,
+            b,
+            length,
+            &output,
+        )
+        .unwrap();
+    } else {
+        call_random_normal(
+            &device,
+            command_buffer,
+            &kernels,
+            name,
+            seed,
+            a,
+            b,
+            length,
+            &output,
+        )
+        .unwrap();
+    }
+
 
     command_buffer.commit();
     command_buffer.wait_until_completed();
@@ -837,24 +852,50 @@ fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32
 
 #[test]
 fn random() {
-    use std::fs::File;
-    use std::io::prelude::*;
 
-    let shape = vec![1024, 4];
-    let seed = 299792458;
-    let min = -30.0;
-    let max = 30.0;
-    let results = run_random::<f32>(seed, &shape, "rand_uniform_f32", min, max);
-    for &v in &results {
-        assert!(v >= min && v <= max);
+    fn calc_mean(data: &[f32]) -> f32 {
+        let sum = data.iter().sum::<f32>() as f32;
+        let count = data.len();
+        assert!(count > 0);
+        sum / count as f32
     }
 
-    // Writing bytes to file for testing with ENT
-    // https://www.fourmilab.ch/random/
-    // TODO: Remove before merge
-    let (head, body, tail) = unsafe { results.align_to::<u8>() };
-    assert!(head.is_empty());
-    assert!(tail.is_empty());
-    let mut file = File::create("test").unwrap();
-    file.write_all(body).unwrap();
-}
+    fn calc_stddev(data: &[f32]) -> f32 {
+        let mean = calc_mean(data);
+        let count = data.len();
+        assert!(count > 0);
+
+        let variance = data.iter().map(|value| {
+            let diff = mean - (*value as f32);
+            diff * diff
+        }).sum::<f32>() / count as f32;
+
+        variance.sqrt()
+    }
+
+    let shape = vec![1024, 10];
+
+    let length = shape.iter().product::<usize>();
+    let seed = 299792458;
+
+    let min = -30.0;
+    let max = 30.0;
+    let mean = 100.0;
+    let stddev = 50.0;
+
+    macro_rules! validate_random {
+        ($type:ty) => {
+            let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
+            results.iter().for_each(|v| assert!(*v >= min && *v <= max));
+            assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
+
+            let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
+            assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
+            assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
+        };
+    }
+
+    validate_random!(f32);
+    validate_random!(f16);
+    validate_random!(bf16);
+}
\ No newline at end of file