Update metal random kernel and set_seed method

* set_seed via buffer content pointer copy + did_modify_range

* ensure random.metal kernel does not write outside of buffer range when tid==0
This commit is contained in:
Ivar Flakstad
2024-01-16 19:11:31 +01:00
parent 79478ff5a1
commit 86a8e58897
2 changed files with 23 additions and 28 deletions

View File

@ -4,11 +4,8 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use cudarc::driver::DeviceRepr;
use metal;
use metal::{
Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
};
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
@ -1546,12 +1543,11 @@ impl BackendDevice for MetalDevice {
Ok(val) => val.parse()?,
_ => 20,
};
let s = device.new_buffer_with_data(
299792458 as *const u32 as *const c_void,
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)?;
let seed = Arc::new(Mutex::new(s));
)));
Ok(Self {
device,
fence,
@ -1676,19 +1672,16 @@ impl BackendDevice for MetalDevice {
}
fn set_seed(&self, seed: u64) -> Result<()> {
if seed > u32::MAX as u64 {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
let seed: u32 = seed.try_into().map_err(|_| {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
})?;
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
let contents = seed_buffer.contents();
unsafe {
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
}
let seed = seed as u32;
let mut s = self.seed.try_lock().map_err(MetalError::from)?;
*s.set_purgeable_state(MTLPurgeableState::Empty);
*s = self.device.new_buffer_with_data(
&seed as *const u32 as *const c_void,
8,
MTLResourceOptions::StorageModeManaged,
)?;
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
Ok(())
}