mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Update metal random kernel and set_seed method
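* set_seed: write the new seed into the existing seed buffer through its contents pointer and call did_modify_range, instead of allocating a new buffer
* random.metal: ensure the kernels do not write outside of the output buffer range when tid == 0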
This commit is contained in:
@ -4,11 +4,8 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
|||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
use candle_metal_kernels;
|
use candle_metal_kernels;
|
||||||
use candle_metal_kernels::Kernels;
|
use candle_metal_kernels::Kernels;
|
||||||
use cudarc::driver::DeviceRepr;
|
|
||||||
use metal;
|
use metal;
|
||||||
use metal::{
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
|
|
||||||
};
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
@ -1546,12 +1543,11 @@ impl BackendDevice for MetalDevice {
|
|||||||
Ok(val) => val.parse()?,
|
Ok(val) => val.parse()?,
|
||||||
_ => 20,
|
_ => 20,
|
||||||
};
|
};
|
||||||
let s = device.new_buffer_with_data(
|
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||||
299792458 as *const u32 as *const c_void,
|
[299792458].as_ptr() as *const c_void,
|
||||||
4,
|
4,
|
||||||
MTLResourceOptions::StorageModeManaged,
|
MTLResourceOptions::StorageModeManaged,
|
||||||
)?;
|
)));
|
||||||
let seed = Arc::new(Mutex::new(s));
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
fence,
|
fence,
|
||||||
@ -1676,19 +1672,16 @@ impl BackendDevice for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
if seed > u32::MAX as u64 {
|
let seed: u32 = seed.try_into().map_err(|_| {
|
||||||
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
|
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
|
||||||
|
let contents = seed_buffer.contents();
|
||||||
|
unsafe {
|
||||||
|
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
|
||||||
}
|
}
|
||||||
let seed = seed as u32;
|
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
|
||||||
|
|
||||||
let mut s = self.seed.try_lock().map_err(MetalError::from)?;
|
|
||||||
*s.set_purgeable_state(MTLPurgeableState::Empty);
|
|
||||||
|
|
||||||
*s = self.device.new_buffer_with_data(
|
|
||||||
&seed as *const u32 as *const c_void,
|
|
||||||
8,
|
|
||||||
MTLResourceOptions::StorageModeManaged,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ static constexpr constant int3 S1 = {13, 19, 12};
|
|||||||
static constexpr constant int3 S2 = {2, 25, 4};
|
static constexpr constant int3 S2 = {2, 25, 4};
|
||||||
static constexpr constant int3 S3 = {3, 11, 17};
|
static constexpr constant int3 S3 = {3, 11, 17};
|
||||||
|
|
||||||
|
// Used to prevent bad seeds.
|
||||||
static constexpr constant uint64_t PHI[16] = {
|
static constexpr constant uint64_t PHI[16] = {
|
||||||
0x9E3779B97F4A7C15,
|
0x9E3779B97F4A7C15,
|
||||||
0xF39CC0605CEDC834,
|
0xF39CC0605CEDC834,
|
||||||
@ -110,10 +111,6 @@ struct HybridTaus {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
METAL_FUNC float absdiff(float x, float y) {
|
|
||||||
return abs(x - y);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T> METAL_FUNC void rand_uniform(
|
template<typename T> METAL_FUNC void rand_uniform(
|
||||||
constant size_t &size,
|
constant size_t &size,
|
||||||
constant float &min,
|
constant float &min,
|
||||||
@ -126,14 +123,16 @@ template<typename T> METAL_FUNC void rand_uniform(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float diff = absdiff(min, max);
|
float diff = abs(min - max);
|
||||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||||
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
||||||
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
|
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||||
|
// Return early if tid == 0, otherwise we will write to out[size].
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
// Use symmetry to fill the other half of the array.
|
||||||
|
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create Gaussian normal distribution using Box-Muller transform:
|
// Create Gaussian normal distribution using Box-Muller transform:
|
||||||
@ -160,11 +159,14 @@ template<typename T> METAL_FUNC void normal(
|
|||||||
float z1 = mag * sinval + mean;
|
float z1 = mag * sinval + mean;
|
||||||
|
|
||||||
out[tid] = static_cast<T>(z0);
|
out[tid] = static_cast<T>(z0);
|
||||||
out[size - tid] = static_cast<T>(z1);
|
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||||
|
// Return early if tid == 0, otherwise we will write to out[size].
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
// Use symmetry to fill the other half of the array.
|
||||||
|
out[size - tid] = static_cast<T>(z1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define UNIFORM_OP(NAME, T) \
|
#define UNIFORM_OP(NAME, T) \
|
||||||
|
Reference in New Issue
Block a user