diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 673e6e11..aa97c04a 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -4,11 +4,8 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
 use candle_metal_kernels::Kernels;
-use cudarc::driver::DeviceRepr;
 use metal;
-use metal::{
-    Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
-};
+use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::path::Path;
@@ -1546,12 +1543,11 @@ impl BackendDevice for MetalDevice {
             Ok(val) => val.parse()?,
             _ => 20,
         };
-        let s = device.new_buffer_with_data(
-            299792458 as *const u32 as *const c_void,
+        let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
+            [299792458].as_ptr() as *const c_void,
             4,
             MTLResourceOptions::StorageModeManaged,
-        )?;
-        let seed = Arc::new(Mutex::new(s));
+        )));
         Ok(Self {
             device,
             fence,
@@ -1676,19 +1672,18 @@ impl BackendDevice for MetalDevice {
     }
 
     fn set_seed(&self, seed: u64) -> Result<()> {
-        if seed > u32::MAX as u64 {
-            MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
+        let seed: u32 = seed.try_into().map_err(|_| {
+            MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
+        })?;
+
+        let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
+        let contents = seed_buffer.contents();
+        // SAFETY: `contents` points at the 4-byte seed buffer allocated at device creation.
+        // `ptr::copy` counts in elements of `u32`, so exactly one element (4 bytes) is copied.
+        unsafe {
+            std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1);
         }
-        let seed = seed as u32;
-
-        let mut s = self.seed.try_lock().map_err(MetalError::from)?;
-        *s.set_purgeable_state(MTLPurgeableState::Empty);
-
-        *s = self.device.new_buffer_with_data(
-            &seed as *const u32 as *const c_void,
-            8,
-            MTLResourceOptions::StorageModeManaged,
-        )?;
+        seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
 
         Ok(())
     }
diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal
index 5eae2715..a7e48393 100644
--- a/candle-metal-kernels/src/random.metal
+++ b/candle-metal-kernels/src/random.metal
@@ -14,6 +14,7 @@ static constexpr constant int3 S1 = {13, 19, 12};
 static constexpr constant int3 S2 = {2, 25, 4};
 static constexpr constant int3 S3 = {3, 11, 17};
 
+// Used to prevent bad seeds.
 static constexpr constant uint64_t PHI[16] = {
     0x9E3779B97F4A7C15,
     0xF39CC0605CEDC834,
@@ -110,10 +111,6 @@ struct HybridTaus {
     }
 };
 
-METAL_FUNC float absdiff(float x, float y) {
-    return abs(x - y);
-}
-
 template<typename T> METAL_FUNC void rand_uniform(
     constant size_t &size,
     constant float &min,
@@ -126,14 +123,16 @@ template<typename T> METAL_FUNC void rand_uniform(
         return;
     }
 
-    float diff = absdiff(min, max);
+    float diff = abs(min - max);
     HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
     out[tid] = static_cast<T>(rng.rand() * diff + min);
-    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
-
     if (tid == 0) {
         atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+        // Return early if tid == 0, otherwise we will write to out[size].
+        return;
     }
+    // Use symmetry to fill the other half of the array.
+    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
 }
 
 // Create Gaussian normal distribution using Box-Muller transform:
@@ -160,11 +159,14 @@ template<typename T> METAL_FUNC void normal(
     float z1 = mag * sinval + mean;
 
     out[tid] = static_cast<T>(z0);
-    out[size - tid] = static_cast<T>(z1);
 
     if (tid == 0) {
         atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+        // Return early if tid == 0, otherwise we will write to out[size].
+        return;
     }
+    // Use symmetry to fill the other half of the array.
+    out[size - tid] = static_cast<T>(z1);
 }
 
 #define UNIFORM_OP(NAME, T) \