mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Seed should be updated by random kernel result.
This commit is contained in:
@ -4,9 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use cudarc::driver::DeviceRepr;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use metal::{
|
||||
Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
|
||||
@ -107,7 +111,7 @@ pub struct MetalDevice {
|
||||
/// (strong_count = 1).
|
||||
buffers: AllocatedBuffers,
|
||||
/// Seed for random number generation.
|
||||
seed: Arc<Mutex<u64>>,
|
||||
seed: Arc<Mutex<Buffer>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -1542,7 +1546,12 @@ impl BackendDevice for MetalDevice {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 20,
|
||||
};
|
||||
let seed = Arc::new(Mutex::new(299792458));
|
||||
let s = device.new_buffer_with_data(
|
||||
299792458 as *const u32 as *const c_void,
|
||||
4,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
)?;
|
||||
let seed = Arc::new(Mutex::new(s));
|
||||
Ok(Self {
|
||||
device,
|
||||
fence,
|
||||
@ -1624,10 +1633,10 @@ impl BackendDevice for MetalDevice {
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
*self.seed.lock().unwrap(),
|
||||
min as f32,
|
||||
max as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1655,10 +1664,10 @@ impl BackendDevice for MetalDevice {
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
*self.seed.lock().unwrap(),
|
||||
mean as f32,
|
||||
stddev as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1667,8 +1676,20 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
if seed > u32::MAX as u64 {
|
||||
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
|
||||
}
|
||||
let seed = seed as u32;
|
||||
|
||||
let mut s = self.seed.try_lock().map_err(MetalError::from)?;
|
||||
*s = seed;
|
||||
*s.set_purgeable_state(MTLPurgeableState::Empty);
|
||||
|
||||
*s = self.device.new_buffer_with_data(
|
||||
&seed as *const u32 as *const c_void,
|
||||
8,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -1587,10 +1587,10 @@ pub fn call_random_uniform(
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
seed: u64,
|
||||
min: f32,
|
||||
max: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
if min >= max {
|
||||
@ -1607,8 +1607,10 @@ pub fn call_random_uniform(
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, seed, min, max, buffer));
|
||||
set_params!(encoder, (length, min, max, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
@ -1623,10 +1625,10 @@ pub fn call_random_normal(
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
seed: u64,
|
||||
mean: f32,
|
||||
stddev: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
@ -1638,8 +1640,10 @@ pub fn call_random_normal(
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, seed, mean, stddev, buffer));
|
||||
set_params!(encoder, (length, mean, stddev, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
|
@ -1,4 +1,7 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_integer>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Constants
|
||||
@ -107,72 +110,85 @@ struct HybridTaus {
|
||||
}
|
||||
};
|
||||
|
||||
METAL_FUNC float absdiff(float x, float y) {
|
||||
return abs(x - y);
|
||||
}
|
||||
|
||||
template<typename T> METAL_FUNC void rand_uniform(
|
||||
constant size_t &size,
|
||||
constant ulong &seed,
|
||||
constant float &min,
|
||||
constant float &max,
|
||||
device atomic_uint *seed,
|
||||
device T *out,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
float diff = max - min;
|
||||
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
|
||||
|
||||
float diff = absdiff(min, max);
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
// Create Gaussian normal distribution using Box-Muller transform:
|
||||
// https://en.wikipedia.org/wiki/Box–Muller_transform
|
||||
template<typename T> METAL_FUNC void normal(
|
||||
constant size_t &size,
|
||||
constant ulong &seed,
|
||||
constant float &mean,
|
||||
constant float &stddev,
|
||||
device atomic_uint *seed,
|
||||
device T *out,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
float u1 = rng.rand();
|
||||
float u2 = rng.rand();
|
||||
|
||||
float cosval;
|
||||
float sinval = sincos(u1 * TWO_PI, cosval);
|
||||
float sinval = sincos(TWO_PI * u2, cosval);
|
||||
float mag = stddev * sqrt(-2.0 * log(u1));
|
||||
float z0 = mag * cosval + mean;
|
||||
float z1 = mag * sinval + mean;
|
||||
|
||||
out[tid] = static_cast<T>(z0);
|
||||
out[size - tid] = static_cast<T>(z1);
|
||||
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#define UNIFORM_OP(NAME, T) \
|
||||
kernel void rand_uniform_##NAME( \
|
||||
constant size_t &size, \
|
||||
constant ulong &seed, \
|
||||
constant float &min, \
|
||||
constant float &max, \
|
||||
device atomic_uint *seed, \
|
||||
device T *out, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
rand_uniform<T>(size, seed, min, max, out, tid); \
|
||||
rand_uniform<T>(size, min, max, seed, out, tid); \
|
||||
} \
|
||||
|
||||
#define NORMAL_OP(NAME, T) \
|
||||
kernel void rand_normal_##NAME( \
|
||||
constant size_t &size, \
|
||||
constant ulong &seed, \
|
||||
constant float &mean, \
|
||||
constant float &stddev, \
|
||||
device atomic_uint *seed, \
|
||||
device T *out, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
normal<T>(size, seed, mean, stddev, out, tid); \
|
||||
normal<T>(size, mean, stddev, seed, out, tid); \
|
||||
} \
|
||||
|
||||
|
||||
|
@ -938,14 +938,21 @@ fn gemm() {
|
||||
);
|
||||
}
|
||||
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
|
||||
|
||||
let seed = device.new_buffer_with_data(
|
||||
&seed as *const u32 as *const core::ffi::c_void,
|
||||
std::mem::size_of::<u32>() as NSUInteger,
|
||||
options,
|
||||
);
|
||||
|
||||
if name.starts_with("rand_uniform") {
|
||||
call_random_uniform(
|
||||
@ -953,10 +960,10 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
seed,
|
||||
a,
|
||||
b,
|
||||
length,
|
||||
&seed,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
@ -966,15 +973,14 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
seed,
|
||||
a,
|
||||
b,
|
||||
length,
|
||||
&seed,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
@ -1029,7 +1035,9 @@ fn random() {
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
|
||||
results.iter().for_each(|v| {
|
||||
assert!(*v >= min && *v <= max);
|
||||
});
|
||||
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
|
||||
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
|
Reference in New Issue
Block a user