diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 8a75bd7c..673e6e11 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,9 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; +use cudarc::driver::DeviceRepr; use metal; -use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use metal::{ + Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger, +}; use std::collections::HashMap; +use std::ffi::c_void; use std::path::Path; use std::sync::{Arc, Mutex, RwLock, TryLockError}; @@ -107,7 +111,7 @@ pub struct MetalDevice { /// (strong_count = 1). buffers: AllocatedBuffers, /// Seed for random number generation. - seed: Arc>, + seed: Arc>, } impl std::fmt::Debug for MetalDevice { @@ -234,7 +238,7 @@ impl MetalDevice { // The slice might not live long enough for metal // To actually fill the GPU buffer. // Putting this wait forces the GPU buffer to be filled - // with the actual data allowing the CPU storage todo + // with the actual data allowing the CPU storage to do // deallocate properly. self.wait_until_completed()?; Ok(real) @@ -1542,7 +1546,12 @@ impl BackendDevice for MetalDevice { Ok(val) => val.parse()?, _ => 20, }; - let seed = Arc::new(Mutex::new(299792458)); + let s = device.new_buffer_with_data( + 299792458 as *const u32 as *const c_void, + 4, + MTLResourceOptions::StorageModeManaged, + )?; + let seed = Arc::new(Mutex::new(s)); Ok(Self { device, fence, @@ -1624,10 +1633,10 @@ impl BackendDevice for MetalDevice { &command_buffer, &self.kernels, name, - *self.seed.lock().unwrap(), min as f32, max as f32, shape.elem_count(), + &*self.seed.lock().unwrap(), &buffer, ) .map_err(MetalError::from)?; @@ -1655,10 +1664,10 @@ impl BackendDevice for MetalDevice { &command_buffer, &self.kernels, name, - *self.seed.lock().unwrap(), mean as f32, stddev as f32, shape.elem_count(), + &*self.seed.lock().unwrap(), &buffer, ) .map_err(MetalError::from)?; @@ -1667,8 +1676,20 @@ impl BackendDevice for MetalDevice { } fn set_seed(&self, seed: u64) -> Result<()> { + if seed > u32::MAX as u64 { + MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())? + } + let seed = seed as u32; + let mut s = self.seed.try_lock().map_err(MetalError::from)?; - *s = seed; + *s.set_purgeable_state(MTLPurgeableState::Empty); + + *s = self.device.new_buffer_with_data( + &seed as *const u32 as *const c_void, + 8, + MTLResourceOptions::StorageModeManaged, + )?; + Ok(()) } } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c427a690..6a10c333 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1587,10 +1587,10 @@ pub fn call_random_uniform( command_buffer: &CommandBufferRef, kernels: &Kernels, name: &'static str, - seed: u64, min: f32, max: f32, length: usize, + seed: &Buffer, buffer: &Buffer, ) -> Result<(), MetalKernelError> { if min >= max { @@ -1607,8 +1607,10 @@ pub fn call_random_uniform( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, seed, min, max, buffer)); + set_params!(encoder, (length, min, max, seed, buffer)); + encoder.use_resource(seed, metal::MTLResourceUsage::Read); + encoder.use_resource(seed, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); @@ -1623,10 +1625,10 @@ pub fn call_random_normal( command_buffer: &CommandBufferRef, kernels: &Kernels, name: &'static str, - seed: u64, mean: f32, stddev: f32, length: usize, + seed: &Buffer, buffer: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Random, name)?; @@ -1638,8 +1640,10 @@ pub fn call_random_normal( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, seed, mean, stddev, buffer)); + set_params!(encoder, (length, mean, stddev, seed, buffer)); + encoder.use_resource(seed, metal::MTLResourceUsage::Read); + encoder.use_resource(seed, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal index 5369e8e2..5eae2715 100644 --- a/candle-metal-kernels/src/random.metal +++ b/candle-metal-kernels/src/random.metal @@ -1,4 +1,7 @@ #include +#include +#include + using namespace metal; // Constants @@ -107,72 +110,85 @@ struct HybridTaus { } }; +METAL_FUNC float absdiff(float x, float y) { + return abs(x - y); +} + template METAL_FUNC void rand_uniform( constant size_t &size, - constant ulong &seed, constant float &min, constant float &max, + device atomic_uint *seed, device T *out, uint tid [[thread_position_in_grid]] ) { if (tid >= size) { return; } - float diff = max - min; - HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); + + float diff = absdiff(min, max); + HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); out[tid] = static_cast(rng.rand() * diff + min); out[size - tid] = static_cast(rng.rand() * diff + min); + + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + } } // Create Gaussian normal distribution using Box-Muller transform: // https://en.wikipedia.org/wiki/Box–Muller_transform template METAL_FUNC void normal( constant size_t &size, - constant ulong &seed, constant float &mean, constant float &stddev, + device atomic_uint *seed, device T *out, uint tid [[thread_position_in_grid]] ) { if (tid >= size) { return; } - HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); + HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); float u1 = rng.rand(); float u2 = rng.rand(); float cosval; - float sinval = sincos(u1 * TWO_PI, cosval); + float sinval = sincos(TWO_PI * u2, cosval); float mag = stddev * sqrt(-2.0 * log(u1)); float z0 = mag * cosval + mean; float z1 = mag * sinval + mean; out[tid] = static_cast(z0); out[size - tid] = static_cast(z1); + + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + } } #define UNIFORM_OP(NAME, T) \ kernel void rand_uniform_##NAME( \ constant size_t &size, \ - constant ulong &seed, \ constant float &min, \ constant float &max, \ + device atomic_uint *seed, \ device T *out, \ uint tid [[thread_position_in_grid]] \ ) { \ - rand_uniform(size, seed, min, max, out, tid); \ + rand_uniform(size, min, max, seed, out, tid); \ } \ #define NORMAL_OP(NAME, T) \ kernel void rand_normal_##NAME( \ constant size_t &size, \ - constant ulong &seed, \ constant float &mean, \ constant float &stddev, \ + device atomic_uint *seed, \ device T *out, \ uint tid [[thread_position_in_grid]] \ ) { \ - normal(size, seed, mean, stddev, out, tid); \ + normal(size, mean, stddev, seed, out, tid); \ } \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 775ee0fa..2831a386 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -938,14 +938,21 @@ fn gemm() { ); } -fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec { +fn run_random(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec { let device = device(); let fence = device.new_fence(); let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; - let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + let output = device.new_buffer((length * core::mem::size_of::()) as NSUInteger, options); + + let seed = device.new_buffer_with_data( + &seed as *const u32 as *const core::ffi::c_void, + std::mem::size_of::() as NSUInteger, + options, + ); if name.starts_with("rand_uniform") { call_random_uniform( @@ -953,10 +960,10 @@ fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: command_buffer, &kernels, name, - seed, a, b, length, + &seed, &output, ) .unwrap(); @@ -966,15 +973,14 @@ fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: command_buffer, &kernels, name, - seed, a, b, length, + &seed, &output, ) .unwrap(); } - command_buffer.commit(); command_buffer.wait_until_completed(); @@ -1029,7 +1035,9 @@ fn random() { .into_iter() .map(f32::from) .collect(); - results.iter().for_each(|v| assert!(*v >= min && *v <= max)); + results.iter().for_each(|v| { + assert!(*v >= min && *v <= max); + }); assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); let results: Vec = run_random::<$type>(