Seed should be updated by random kernel result.

This commit is contained in:
Ivar Flakstad
2024-01-14 18:10:54 +01:00
parent ecf88a6d38
commit 79478ff5a1
4 changed files with 76 additions and 27 deletions

View File

@ -4,9 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use cudarc::driver::DeviceRepr;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use metal::{
Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, TryLockError};
@ -107,7 +111,7 @@ pub struct MetalDevice {
/// (strong_count = 1).
buffers: AllocatedBuffers,
/// Seed for random number generation.
seed: Arc<Mutex<u64>>,
seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
@ -234,7 +238,7 @@ impl MetalDevice {
// The slice might not live long enough for metal
// To actually fill the GPU buffer.
// Putting this wait forces the GPU buffer to be filled
// with the actual data allowing the CPU storage todo
// with the actual data allowing the CPU storage to do
// deallocate properly.
self.wait_until_completed()?;
Ok(real)
@ -1542,7 +1546,12 @@ impl BackendDevice for MetalDevice {
Ok(val) => val.parse()?,
_ => 20,
};
let seed = Arc::new(Mutex::new(299792458));
let s = device.new_buffer_with_data(
299792458 as *const u32 as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)?;
let seed = Arc::new(Mutex::new(s));
Ok(Self {
device,
fence,
@ -1624,10 +1633,10 @@ impl BackendDevice for MetalDevice {
&command_buffer,
&self.kernels,
name,
*self.seed.lock().unwrap(),
min as f32,
max as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1655,10 +1664,10 @@ impl BackendDevice for MetalDevice {
&command_buffer,
&self.kernels,
name,
*self.seed.lock().unwrap(),
mean as f32,
stddev as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1667,8 +1676,20 @@ impl BackendDevice for MetalDevice {
}
fn set_seed(&self, seed: u64) -> Result<()> {
if seed > u32::MAX as u64 {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
}
let seed = seed as u32;
let mut s = self.seed.try_lock().map_err(MetalError::from)?;
*s = seed;
*s.set_purgeable_state(MTLPurgeableState::Empty);
*s = self.device.new_buffer_with_data(
&seed as *const u32 as *const c_void,
8,
MTLResourceOptions::StorageModeManaged,
)?;
Ok(())
}
}