Seed should be updated by random kernel result.

This commit is contained in:
Ivar Flakstad
2024-01-14 18:10:54 +01:00
parent ecf88a6d38
commit 79478ff5a1
4 changed files with 76 additions and 27 deletions

View File

@ -4,9 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape}; use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels; use candle_metal_kernels;
use candle_metal_kernels::Kernels; use candle_metal_kernels::Kernels;
use cudarc::driver::DeviceRepr;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{
Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
};
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path; use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, TryLockError}; use std::sync::{Arc, Mutex, RwLock, TryLockError};
@ -107,7 +111,7 @@ pub struct MetalDevice {
/// (strong_count = 1). /// (strong_count = 1).
buffers: AllocatedBuffers, buffers: AllocatedBuffers,
/// Seed for random number generation. /// Seed for random number generation.
seed: Arc<Mutex<u64>>, seed: Arc<Mutex<Buffer>>,
} }
impl std::fmt::Debug for MetalDevice { impl std::fmt::Debug for MetalDevice {
@ -1542,7 +1546,12 @@ impl BackendDevice for MetalDevice {
Ok(val) => val.parse()?, Ok(val) => val.parse()?,
_ => 20, _ => 20,
}; };
let seed = Arc::new(Mutex::new(299792458)); let s = device.new_buffer_with_data(
299792458 as *const u32 as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)?;
let seed = Arc::new(Mutex::new(s));
Ok(Self { Ok(Self {
device, device,
fence, fence,
@ -1624,10 +1633,10 @@ impl BackendDevice for MetalDevice {
&command_buffer, &command_buffer,
&self.kernels, &self.kernels,
name, name,
*self.seed.lock().unwrap(),
min as f32, min as f32,
max as f32, max as f32,
shape.elem_count(), shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1655,10 +1664,10 @@ impl BackendDevice for MetalDevice {
&command_buffer, &command_buffer,
&self.kernels, &self.kernels,
name, name,
*self.seed.lock().unwrap(),
mean as f32, mean as f32,
stddev as f32, stddev as f32,
shape.elem_count(), shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1667,8 +1676,20 @@ impl BackendDevice for MetalDevice {
} }
fn set_seed(&self, seed: u64) -> Result<()> { fn set_seed(&self, seed: u64) -> Result<()> {
if seed > u32::MAX as u64 {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
}
let seed = seed as u32;
let mut s = self.seed.try_lock().map_err(MetalError::from)?; let mut s = self.seed.try_lock().map_err(MetalError::from)?;
*s = seed; *s.set_purgeable_state(MTLPurgeableState::Empty);
*s = self.device.new_buffer_with_data(
&seed as *const u32 as *const c_void,
8,
MTLResourceOptions::StorageModeManaged,
)?;
Ok(()) Ok(())
} }
} }

View File

@ -1587,10 +1587,10 @@ pub fn call_random_uniform(
command_buffer: &CommandBufferRef, command_buffer: &CommandBufferRef,
kernels: &Kernels, kernels: &Kernels,
name: &'static str, name: &'static str,
seed: u64,
min: f32, min: f32,
max: f32, max: f32,
length: usize, length: usize,
seed: &Buffer,
buffer: &Buffer, buffer: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
if min >= max { if min >= max {
@ -1607,8 +1607,10 @@ pub fn call_random_uniform(
encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, seed, min, max, buffer)); set_params!(encoder, (length, min, max, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
@ -1623,10 +1625,10 @@ pub fn call_random_normal(
command_buffer: &CommandBufferRef, command_buffer: &CommandBufferRef,
kernels: &Kernels, kernels: &Kernels,
name: &'static str, name: &'static str,
seed: u64,
mean: f32, mean: f32,
stddev: f32, stddev: f32,
length: usize, length: usize,
seed: &Buffer,
buffer: &Buffer, buffer: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?; let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
@ -1638,8 +1640,10 @@ pub fn call_random_normal(
encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, seed, mean, stddev, buffer)); set_params!(encoder, (length, mean, stddev, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);

View File

@ -1,4 +1,7 @@
#include <metal_stdlib> #include <metal_stdlib>
#include <metal_integer>
#include <metal_atomic>
using namespace metal; using namespace metal;
// Constants // Constants
@ -107,72 +110,85 @@ struct HybridTaus {
} }
}; };
METAL_FUNC float absdiff(float x, float y) {
return abs(x - y);
}
template<typename T> METAL_FUNC void rand_uniform( template<typename T> METAL_FUNC void rand_uniform(
constant size_t &size, constant size_t &size,
constant ulong &seed,
constant float &min, constant float &min,
constant float &max, constant float &max,
device atomic_uint *seed,
device T *out, device T *out,
uint tid [[thread_position_in_grid]] uint tid [[thread_position_in_grid]]
) { ) {
if (tid >= size) { if (tid >= size) {
return; return;
} }
float diff = max - min;
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); float diff = absdiff(min, max);
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min); out[tid] = static_cast<T>(rng.rand() * diff + min);
out[size - tid] = static_cast<T>(rng.rand() * diff + min); out[size - tid] = static_cast<T>(rng.rand() * diff + min);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
}
} }
// Create Gaussian normal distribution using Box-Muller transform: // Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/BoxMuller_transform // https://en.wikipedia.org/wiki/BoxMuller_transform
template<typename T> METAL_FUNC void normal( template<typename T> METAL_FUNC void normal(
constant size_t &size, constant size_t &size,
constant ulong &seed,
constant float &mean, constant float &mean,
constant float &stddev, constant float &stddev,
device atomic_uint *seed,
device T *out, device T *out,
uint tid [[thread_position_in_grid]] uint tid [[thread_position_in_grid]]
) { ) {
if (tid >= size) { if (tid >= size) {
return; return;
} }
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
float u1 = rng.rand(); float u1 = rng.rand();
float u2 = rng.rand(); float u2 = rng.rand();
float cosval; float cosval;
float sinval = sincos(u1 * TWO_PI, cosval); float sinval = sincos(TWO_PI * u2, cosval);
float mag = stddev * sqrt(-2.0 * log(u1)); float mag = stddev * sqrt(-2.0 * log(u1));
float z0 = mag * cosval + mean; float z0 = mag * cosval + mean;
float z1 = mag * sinval + mean; float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0); out[tid] = static_cast<T>(z0);
out[size - tid] = static_cast<T>(z1); out[size - tid] = static_cast<T>(z1);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
}
} }
#define UNIFORM_OP(NAME, T) \ #define UNIFORM_OP(NAME, T) \
kernel void rand_uniform_##NAME( \ kernel void rand_uniform_##NAME( \
constant size_t &size, \ constant size_t &size, \
constant ulong &seed, \
constant float &min, \ constant float &min, \
constant float &max, \ constant float &max, \
device atomic_uint *seed, \
device T *out, \ device T *out, \
uint tid [[thread_position_in_grid]] \ uint tid [[thread_position_in_grid]] \
) { \ ) { \
rand_uniform<T>(size, seed, min, max, out, tid); \ rand_uniform<T>(size, min, max, seed, out, tid); \
} \ } \
#define NORMAL_OP(NAME, T) \ #define NORMAL_OP(NAME, T) \
kernel void rand_normal_##NAME( \ kernel void rand_normal_##NAME( \
constant size_t &size, \ constant size_t &size, \
constant ulong &seed, \
constant float &mean, \ constant float &mean, \
constant float &stddev, \ constant float &stddev, \
device atomic_uint *seed, \
device T *out, \ device T *out, \
uint tid [[thread_position_in_grid]] \ uint tid [[thread_position_in_grid]] \
) { \ ) { \
normal<T>(size, seed, mean, stddev, out, tid); \ normal<T>(size, mean, stddev, seed, out, tid); \
} \ } \

View File

@ -938,14 +938,21 @@ fn gemm() {
); );
} }
fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> { fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device(); let device = device();
let fence = device.new_fence(); let fence = device.new_fence();
let kernels = Kernels::new(fence); let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
let seed = device.new_buffer_with_data(
&seed as *const u32 as *const core::ffi::c_void,
std::mem::size_of::<u32>() as NSUInteger,
options,
);
if name.starts_with("rand_uniform") { if name.starts_with("rand_uniform") {
call_random_uniform( call_random_uniform(
@ -953,10 +960,10 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
command_buffer, command_buffer,
&kernels, &kernels,
name, name,
seed,
a, a,
b, b,
length, length,
&seed,
&output, &output,
) )
.unwrap(); .unwrap();
@ -966,15 +973,14 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
command_buffer, command_buffer,
&kernels, &kernels,
name, name,
seed,
a, a,
b, b,
length, length,
&seed,
&output, &output,
) )
.unwrap(); .unwrap();
} }
command_buffer.commit(); command_buffer.commit();
command_buffer.wait_until_completed(); command_buffer.wait_until_completed();
@ -1029,7 +1035,9 @@ fn random() {
.into_iter() .into_iter()
.map(f32::from) .map(f32::from)
.collect(); .collect();
results.iter().for_each(|v| assert!(*v >= min && *v <= max)); results.iter().for_each(|v| {
assert!(*v >= min && *v <= max);
});
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>( let results: Vec<f32> = run_random::<$type>(