mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Seed should be updated by random kernel result.
This commit is contained in:
@@ -4,9 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
 use candle_metal_kernels::Kernels;
+use cudarc::driver::DeviceRepr;
 use metal;
-use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use metal::{
+Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
+};
 use std::collections::HashMap;
+use std::ffi::c_void;
 use std::path::Path;
 use std::sync::{Arc, Mutex, RwLock, TryLockError};
 
@@ -107,7 +111,7 @@ pub struct MetalDevice {
 /// (strong_count = 1).
 buffers: AllocatedBuffers,
 /// Seed for random number generation.
-seed: Arc<Mutex<u64>>,
+seed: Arc<Mutex<Buffer>>,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -234,7 +238,7 @@ impl MetalDevice {
 // The slice might not live long enough for metal
 // To actually fill the GPU buffer.
 // Putting this wait forces the GPU buffer to be filled
-// with the actual data allowing the CPU storage todo
+// with the actual data allowing the CPU storage to do
 // deallocate properly.
 self.wait_until_completed()?;
 Ok(real)
@@ -1542,7 +1546,12 @@ impl BackendDevice for MetalDevice {
 Ok(val) => val.parse()?,
 _ => 20,
 };
-let seed = Arc::new(Mutex::new(299792458));
+let s = device.new_buffer_with_data(
+299792458 as *const u32 as *const c_void,
+4,
+MTLResourceOptions::StorageModeManaged,
+)?;
+let seed = Arc::new(Mutex::new(s));
 Ok(Self {
 device,
 fence,
@@ -1624,10 +1633,10 @@ impl BackendDevice for MetalDevice {
 &command_buffer,
 &self.kernels,
 name,
-*self.seed.lock().unwrap(),
 min as f32,
 max as f32,
 shape.elem_count(),
+&*self.seed.lock().unwrap(),
 &buffer,
 )
 .map_err(MetalError::from)?;
@@ -1655,10 +1664,10 @@ impl BackendDevice for MetalDevice {
 &command_buffer,
 &self.kernels,
 name,
-*self.seed.lock().unwrap(),
 mean as f32,
 stddev as f32,
 shape.elem_count(),
+&*self.seed.lock().unwrap(),
 &buffer,
 )
 .map_err(MetalError::from)?;
@@ -1667,8 +1676,20 @@ impl BackendDevice for MetalDevice {
 }
 
 fn set_seed(&self, seed: u64) -> Result<()> {
+if seed > u32::MAX as u64 {
+MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
+}
+let seed = seed as u32;
+
 let mut s = self.seed.try_lock().map_err(MetalError::from)?;
-*s = seed;
+*s.set_purgeable_state(MTLPurgeableState::Empty);
+
+*s = self.device.new_buffer_with_data(
+&seed as *const u32 as *const c_void,
+8,
+MTLResourceOptions::StorageModeManaged,
+)?;
+
 Ok(())
 }
 }
|
@@ -1587,10 +1587,10 @@ pub fn call_random_uniform(
 command_buffer: &CommandBufferRef,
 kernels: &Kernels,
 name: &'static str,
-seed: u64,
 min: f32,
 max: f32,
 length: usize,
+seed: &Buffer,
 buffer: &Buffer,
 ) -> Result<(), MetalKernelError> {
 if min >= max {
@@ -1607,8 +1607,10 @@ pub fn call_random_uniform(
 encoder.wait_for_fence(&kernels.fence);
 encoder.set_compute_pipeline_state(&pipeline);
 
-set_params!(encoder, (length, seed, min, max, buffer));
+set_params!(encoder, (length, min, max, seed, buffer));
 
+encoder.use_resource(seed, metal::MTLResourceUsage::Read);
+encoder.use_resource(seed, metal::MTLResourceUsage::Write);
 encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
 encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
 encoder.update_fence(&kernels.fence);
@@ -1623,10 +1625,10 @@ pub fn call_random_normal(
 command_buffer: &CommandBufferRef,
 kernels: &Kernels,
 name: &'static str,
-seed: u64,
 mean: f32,
 stddev: f32,
 length: usize,
+seed: &Buffer,
 buffer: &Buffer,
 ) -> Result<(), MetalKernelError> {
 let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
@@ -1638,8 +1640,10 @@ pub fn call_random_normal(
 encoder.wait_for_fence(&kernels.fence);
 encoder.set_compute_pipeline_state(&pipeline);
 
-set_params!(encoder, (length, seed, mean, stddev, buffer));
+set_params!(encoder, (length, mean, stddev, seed, buffer));
 
+encoder.use_resource(seed, metal::MTLResourceUsage::Read);
+encoder.use_resource(seed, metal::MTLResourceUsage::Write);
 encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
 encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
 encoder.update_fence(&kernels.fence);
|
@@ -1,4 +1,7 @@
 #include <metal_stdlib>
+#include <metal_integer>
+#include <metal_atomic>
+
 using namespace metal;
 
 // Constants
@@ -107,72 +110,85 @@ struct HybridTaus {
 }
 };
 
+METAL_FUNC float absdiff(float x, float y) {
+return abs(x - y);
+}
+
 template<typename T> METAL_FUNC void rand_uniform(
 constant size_t &size,
-constant ulong &seed,
 constant float &min,
 constant float &max,
+device atomic_uint *seed,
 device T *out,
 uint tid [[thread_position_in_grid]]
 ) {
 if (tid >= size) {
 return;
 }
-float diff = max - min;
-HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
+
+float diff = absdiff(min, max);
+HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
 out[tid] = static_cast<T>(rng.rand() * diff + min);
 out[size - tid] = static_cast<T>(rng.rand() * diff + min);
+
+if (tid == 0) {
+atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+}
 }
 
 // Create Gaussian normal distribution using Box-Muller transform:
 // https://en.wikipedia.org/wiki/Box–Muller_transform
 template<typename T> METAL_FUNC void normal(
 constant size_t &size,
-constant ulong &seed,
 constant float &mean,
 constant float &stddev,
+device atomic_uint *seed,
 device T *out,
 uint tid [[thread_position_in_grid]]
 ) {
 if (tid >= size) {
 return;
 }
-HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
+HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
 float u1 = rng.rand();
 float u2 = rng.rand();
 
 float cosval;
-float sinval = sincos(u1 * TWO_PI, cosval);
+float sinval = sincos(TWO_PI * u2, cosval);
 float mag = stddev * sqrt(-2.0 * log(u1));
 float z0 = mag * cosval + mean;
 float z1 = mag * sinval + mean;
 
 out[tid] = static_cast<T>(z0);
 out[size - tid] = static_cast<T>(z1);
+
+if (tid == 0) {
+atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+}
 }
 
 #define UNIFORM_OP(NAME, T) \
 kernel void rand_uniform_##NAME( \
 constant size_t &size, \
-constant ulong &seed, \
 constant float &min, \
 constant float &max, \
+device atomic_uint *seed, \
 device T *out, \
 uint tid [[thread_position_in_grid]] \
 ) { \
-rand_uniform<T>(size, seed, min, max, out, tid); \
+rand_uniform<T>(size, min, max, seed, out, tid); \
 } \
 
 #define NORMAL_OP(NAME, T) \
 kernel void rand_normal_##NAME( \
 constant size_t &size, \
-constant ulong &seed, \
 constant float &mean, \
 constant float &stddev, \
+device atomic_uint *seed, \
 device T *out, \
 uint tid [[thread_position_in_grid]] \
 ) { \
-normal<T>(size, seed, mean, stddev, out, tid); \
+normal<T>(size, mean, stddev, seed, out, tid); \
 } \
 
 
|
@@ -938,14 +938,21 @@ fn gemm() {
 );
 }
 
-fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
+fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
 let device = device();
 let fence = device.new_fence();
 let kernels = Kernels::new(fence);
 let command_queue = device.new_command_queue();
 let command_buffer = command_queue.new_command_buffer();
+
 let options = MTLResourceOptions::StorageModeManaged;
-let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
+
+let seed = device.new_buffer_with_data(
+&seed as *const u32 as *const core::ffi::c_void,
+std::mem::size_of::<u32>() as NSUInteger,
+options,
+);
 
 if name.starts_with("rand_uniform") {
 call_random_uniform(
@@ -953,10 +960,10 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
 command_buffer,
 &kernels,
 name,
-seed,
 a,
 b,
 length,
+&seed,
 &output,
 )
 .unwrap();
@@ -966,15 +973,14 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
 command_buffer,
 &kernels,
 name,
-seed,
 a,
 b,
 length,
+&seed,
 &output,
 )
 .unwrap();
 }
-
 command_buffer.commit();
 command_buffer.wait_until_completed();
 
@@ -1029,7 +1035,9 @@ fn random() {
 .into_iter()
 .map(f32::from)
 .collect();
-results.iter().for_each(|v| assert!(*v >= min && *v <= max));
+results.iter().for_each(|v| {
+assert!(*v >= min && *v <= max);
+});
 assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
 
 let results: Vec<f32> = run_random::<$type>(
|
Reference in New Issue
Block a user