Seed should be updated by random kernel result.

This commit is contained in:
Ivar Flakstad
2024-01-14 18:10:54 +01:00
parent ecf88a6d38
commit 79478ff5a1
4 changed files with 76 additions and 27 deletions

View File

@ -4,9 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use cudarc::driver::DeviceRepr;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use metal::{
Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, TryLockError};
@ -107,7 +111,7 @@ pub struct MetalDevice {
/// (strong_count = 1).
buffers: AllocatedBuffers,
/// Seed for random number generation.
seed: Arc<Mutex<u64>>,
seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
@ -234,7 +238,7 @@ impl MetalDevice {
// The slice might not live long enough for metal
// To actually fill the GPU buffer.
// Putting this wait forces the GPU buffer to be filled
// with the actual data allowing the CPU storage todo
// with the actual data allowing the CPU storage to do
// deallocate properly.
self.wait_until_completed()?;
Ok(real)
@ -1542,7 +1546,12 @@ impl BackendDevice for MetalDevice {
Ok(val) => val.parse()?,
_ => 20,
};
let seed = Arc::new(Mutex::new(299792458));
let s = device.new_buffer_with_data(
299792458 as *const u32 as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)?;
let seed = Arc::new(Mutex::new(s));
Ok(Self {
device,
fence,
@ -1624,10 +1633,10 @@ impl BackendDevice for MetalDevice {
&command_buffer,
&self.kernels,
name,
*self.seed.lock().unwrap(),
min as f32,
max as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1655,10 +1664,10 @@ impl BackendDevice for MetalDevice {
&command_buffer,
&self.kernels,
name,
*self.seed.lock().unwrap(),
mean as f32,
stddev as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1667,8 +1676,20 @@ impl BackendDevice for MetalDevice {
}
fn set_seed(&self, seed: u64) -> Result<()> {
if seed > u32::MAX as u64 {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
}
let seed = seed as u32;
let mut s = self.seed.try_lock().map_err(MetalError::from)?;
*s = seed;
*s.set_purgeable_state(MTLPurgeableState::Empty);
*s = self.device.new_buffer_with_data(
&seed as *const u32 as *const c_void,
8,
MTLResourceOptions::StorageModeManaged,
)?;
Ok(())
}
}

View File

@ -1587,10 +1587,10 @@ pub fn call_random_uniform(
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
seed: u64,
min: f32,
max: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
if min >= max {
@ -1607,8 +1607,10 @@ pub fn call_random_uniform(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, seed, min, max, buffer));
set_params!(encoder, (length, min, max, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
@ -1623,10 +1625,10 @@ pub fn call_random_normal(
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
seed: u64,
mean: f32,
stddev: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
@ -1638,8 +1640,10 @@ pub fn call_random_normal(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, seed, mean, stddev, buffer));
set_params!(encoder, (length, mean, stddev, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);

View File

@ -1,4 +1,7 @@
#include <metal_stdlib>
#include <metal_integer>
#include <metal_atomic>
using namespace metal;
// Constants
@ -107,72 +110,85 @@ struct HybridTaus {
}
};
METAL_FUNC float absdiff(float x, float y) {
return abs(x - y);
}
template<typename T> METAL_FUNC void rand_uniform(
constant size_t &size,
constant ulong &seed,
constant float &min,
constant float &max,
device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
float diff = max - min;
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
float diff = absdiff(min, max);
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min);
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
}
}
// Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/BoxMuller_transform
template<typename T> METAL_FUNC void normal(
constant size_t &size,
constant ulong &seed,
constant float &mean,
constant float &stddev,
device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
float u1 = rng.rand();
float u2 = rng.rand();
float cosval;
float sinval = sincos(u1 * TWO_PI, cosval);
float sinval = sincos(TWO_PI * u2, cosval);
float mag = stddev * sqrt(-2.0 * log(u1));
float z0 = mag * cosval + mean;
float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0);
out[size - tid] = static_cast<T>(z1);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
}
}
#define UNIFORM_OP(NAME, T) \
kernel void rand_uniform_##NAME( \
constant size_t &size, \
constant ulong &seed, \
constant float &min, \
constant float &max, \
device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
rand_uniform<T>(size, seed, min, max, out, tid); \
rand_uniform<T>(size, min, max, seed, out, tid); \
} \
#define NORMAL_OP(NAME, T) \
kernel void rand_normal_##NAME( \
constant size_t &size, \
constant ulong &seed, \
constant float &mean, \
constant float &stddev, \
device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
normal<T>(size, seed, mean, stddev, out, tid); \
normal<T>(size, mean, stddev, seed, out, tid); \
} \

View File

@ -938,14 +938,21 @@ fn gemm() {
);
}
fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
let seed = device.new_buffer_with_data(
&seed as *const u32 as *const core::ffi::c_void,
std::mem::size_of::<u32>() as NSUInteger,
options,
);
if name.starts_with("rand_uniform") {
call_random_uniform(
@ -953,10 +960,10 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
command_buffer,
&kernels,
name,
seed,
a,
b,
length,
&seed,
&output,
)
.unwrap();
@ -966,15 +973,14 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
command_buffer,
&kernels,
name,
seed,
a,
b,
length,
&seed,
&output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
@ -1029,7 +1035,9 @@ fn random() {
.into_iter()
.map(f32::from)
.collect();
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
results.iter().for_each(|v| {
assert!(*v >= min && *v <= max);
});
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>(