Update metal random kernel and set_seed method

* Implement set_seed by copying the new seed through the seed buffer's contents pointer and then calling did_modify_range

* Ensure the random.metal kernels do not write outside the output buffer's range when tid == 0
This commit is contained in:
Ivar Flakstad
2024-01-16 19:11:31 +01:00
parent 79478ff5a1
commit 86a8e58897
2 changed files with 23 additions and 28 deletions

View File

@ -4,11 +4,8 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape}; use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels; use candle_metal_kernels;
use candle_metal_kernels::Kernels; use candle_metal_kernels::Kernels;
use cudarc::driver::DeviceRepr;
use metal; use metal;
use metal::{ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
};
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void; use std::ffi::c_void;
use std::path::Path; use std::path::Path;
@ -1546,12 +1543,11 @@ impl BackendDevice for MetalDevice {
Ok(val) => val.parse()?, Ok(val) => val.parse()?,
_ => 20, _ => 20,
}; };
let s = device.new_buffer_with_data( let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
299792458 as *const u32 as *const c_void, [299792458].as_ptr() as *const c_void,
4, 4,
MTLResourceOptions::StorageModeManaged, MTLResourceOptions::StorageModeManaged,
)?; )));
let seed = Arc::new(Mutex::new(s));
Ok(Self { Ok(Self {
device, device,
fence, fence,
@ -1676,19 +1672,16 @@ impl BackendDevice for MetalDevice {
} }
fn set_seed(&self, seed: u64) -> Result<()> { fn set_seed(&self, seed: u64) -> Result<()> {
if seed > u32::MAX as u64 { let seed: u32 = seed.try_into().map_err(|_| {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())? MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
})?;
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
let contents = seed_buffer.contents();
unsafe {
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
} }
let seed = seed as u32; seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
let mut s = self.seed.try_lock().map_err(MetalError::from)?;
*s.set_purgeable_state(MTLPurgeableState::Empty);
*s = self.device.new_buffer_with_data(
&seed as *const u32 as *const c_void,
8,
MTLResourceOptions::StorageModeManaged,
)?;
Ok(()) Ok(())
} }

View File

@ -14,6 +14,7 @@ static constexpr constant int3 S1 = {13, 19, 12};
static constexpr constant int3 S2 = {2, 25, 4}; static constexpr constant int3 S2 = {2, 25, 4};
static constexpr constant int3 S3 = {3, 11, 17}; static constexpr constant int3 S3 = {3, 11, 17};
// Used to prevent bad seeds.
static constexpr constant uint64_t PHI[16] = { static constexpr constant uint64_t PHI[16] = {
0x9E3779B97F4A7C15, 0x9E3779B97F4A7C15,
0xF39CC0605CEDC834, 0xF39CC0605CEDC834,
@ -110,10 +111,6 @@ struct HybridTaus {
} }
}; };
METAL_FUNC float absdiff(float x, float y) {
return abs(x - y);
}
template<typename T> METAL_FUNC void rand_uniform( template<typename T> METAL_FUNC void rand_uniform(
constant size_t &size, constant size_t &size,
constant float &min, constant float &min,
@ -126,14 +123,16 @@ template<typename T> METAL_FUNC void rand_uniform(
return; return;
} }
float diff = absdiff(min, max); float diff = abs(min - max);
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min); out[tid] = static_cast<T>(rng.rand() * diff + min);
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
if (tid == 0) { if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
// Return early if tid == 0, otherwise we will write to out[size].
return;
} }
// Use symmetry to fill the other half of the array.
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
} }
// Create Gaussian normal distribution using Box-Muller transform: // Create Gaussian normal distribution using Box-Muller transform:
@ -160,11 +159,14 @@ template<typename T> METAL_FUNC void normal(
float z1 = mag * sinval + mean; float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0); out[tid] = static_cast<T>(z0);
out[size - tid] = static_cast<T>(z1);
if (tid == 0) { if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
// Return early if tid == 0, otherwise we will write to out[size].
return;
} }
// Use symmetry to fill the other half of the array.
out[size - tid] = static_cast<T>(z1);
} }
#define UNIFORM_OP(NAME, T) \ #define UNIFORM_OP(NAME, T) \