mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Merge branch 'main' into ivarflakstad/metal-fill
This commit is contained in:
@ -2,8 +2,9 @@ mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
criterion_main!(
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::fill::benches,
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::where_cond::benches
|
||||
);
|
||||
|
@ -1,6 +1,7 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod fill;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
63
candle-core/benches/benchmarks/random.rs
Normal file
63
candle-core/benches/benchmarks/random.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn rand_uniform(a: &Tensor) {
|
||||
a.rand_like(-1.0, 123.0).unwrap();
|
||||
}
|
||||
|
||||
fn rand_normal(a: &Tensor) {
|
||||
a.randn_like(100.0, 15.0).unwrap();
|
||||
}
|
||||
|
||||
fn run_random_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
|
||||
let rows = 2048;
|
||||
let cols = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let flops = b * rows * cols * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_uniform(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_normal"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_normal(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_random_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -8,8 +8,9 @@ use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock, TryLockError};
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
|
||||
/// Simple way to catch lock error without
|
||||
/// depending on T
|
||||
@ -102,6 +103,8 @@ pub struct MetalDevice {
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
|
||||
/// (strong_count = 1).
|
||||
buffers: AllocatedBuffers,
|
||||
/// Seed for random number generation.
|
||||
seed: Arc<Mutex<Buffer>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -1555,6 +1558,11 @@ impl BackendDevice for MetalDevice {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 10,
|
||||
};
|
||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||
[299792458].as_ptr() as *const c_void,
|
||||
4,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
)));
|
||||
Ok(Self {
|
||||
device,
|
||||
command_queue,
|
||||
@ -1563,13 +1571,10 @@ impl BackendDevice for MetalDevice {
|
||||
compute_per_buffer,
|
||||
buffers,
|
||||
kernels,
|
||||
seed,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||
crate::bail!("Metal set_seed not implemented")
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Metal {
|
||||
gpu_id: self.registry_id() as usize,
|
||||
@ -1618,7 +1623,7 @@ impl BackendDevice for MetalDevice {
|
||||
DType::F16 => fill!(f16::ONE),
|
||||
DType::F32 => fill!(1f32),
|
||||
DType::F64 => {
|
||||
return Err(MetalError::Message(format!("metal doesn't support double")).into())
|
||||
return Err(MetalError::Message("Metal doesn't support double".to_string()).into())
|
||||
}
|
||||
}
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
@ -1641,12 +1646,31 @@ impl BackendDevice for MetalDevice {
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
min: f64,
|
||||
max: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
let name = match dtype {
|
||||
DType::F32 => "rand_uniform_f32",
|
||||
DType::F16 => "rand_uniform_f16",
|
||||
DType::BF16 => "rand_uniform_bf16",
|
||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_random_uniform(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
min as f32,
|
||||
max as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn rand_normal(
|
||||
@ -1656,9 +1680,43 @@ impl BackendDevice for MetalDevice {
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
let name = match dtype {
|
||||
DType::F32 => "rand_normal_f32",
|
||||
DType::F16 => "rand_normal_f16",
|
||||
DType::BF16 => "rand_normal_bf16",
|
||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_random_normal(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
mean as f32,
|
||||
stddev as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
let seed: u32 = seed.try_into().map_err(|_| {
|
||||
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
|
||||
})?;
|
||||
|
||||
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
|
||||
let contents = seed_buffer.contents();
|
||||
unsafe {
|
||||
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
|
||||
}
|
||||
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,8 +14,9 @@ const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const FILL: &str = include_str!("fill.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
|
||||
@ -48,9 +49,10 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
|
||||
/// Helper functions to create the various objects on the compute command encoder
|
||||
/// on a single line.
|
||||
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
||||
pub trait EncoderParam {
|
||||
pub trait EncoderParam: private::Sealed {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
||||
}
|
||||
|
||||
macro_rules! primitive {
|
||||
($type:ty) => {
|
||||
impl EncoderParam for $type {
|
||||
@ -64,14 +66,14 @@ macro_rules! primitive {
|
||||
}
|
||||
};
|
||||
}
|
||||
primitive!(usize);
|
||||
primitive!(u8);
|
||||
primitive!(u32);
|
||||
primitive!(i32);
|
||||
primitive!(i64);
|
||||
primitive!(f16);
|
||||
primitive!(bf16);
|
||||
primitive!(f32);
|
||||
macro_rules! primitives {
|
||||
($($type:ty),+) => {
|
||||
$(
|
||||
primitive!($type);
|
||||
)+
|
||||
};
|
||||
}
|
||||
primitives!(bool, usize, u8, u32, u64, i32, i64, f16, bf16, f32);
|
||||
|
||||
impl<T> EncoderParam for &[T] {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
@ -114,6 +116,38 @@ macro_rules! set_params {
|
||||
);
|
||||
}
|
||||
|
||||
// Seal the trait so that only the types we want can implement it
|
||||
mod private {
|
||||
use super::*;
|
||||
|
||||
pub trait Sealed {}
|
||||
|
||||
macro_rules! sealed {
|
||||
($($type:ty),+) => {
|
||||
$(
|
||||
impl Sealed for $type {}
|
||||
)+
|
||||
};
|
||||
}
|
||||
sealed!(
|
||||
usize,
|
||||
u8,
|
||||
u32,
|
||||
u64,
|
||||
i32,
|
||||
i64,
|
||||
f16,
|
||||
bf16,
|
||||
f32,
|
||||
bool,
|
||||
&Buffer,
|
||||
(&Buffer, usize),
|
||||
&mut Buffer,
|
||||
(&mut Buffer, usize)
|
||||
);
|
||||
impl<T> Sealed for &[T] {}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
Affine,
|
||||
@ -126,6 +160,7 @@ pub enum Source {
|
||||
Mfa,
|
||||
Conv,
|
||||
Fill,
|
||||
Random,
|
||||
Quantized,
|
||||
}
|
||||
|
||||
@ -250,6 +285,7 @@ impl Kernels {
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Fill => FILL,
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
@ -1536,6 +1572,73 @@ pub fn call_upsample_nearest_2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_random_uniform(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
min: f32,
|
||||
max: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
if min >= max {
|
||||
return Err(MetalKernelError::LoadLibraryError(
|
||||
"min must be less than max".to_string(),
|
||||
));
|
||||
}
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, min, max, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_random_normal(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
mean: f32,
|
||||
stddev: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, mean, stddev, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum GgmlDType {
|
||||
Q4_0,
|
||||
@ -1767,7 +1870,7 @@ macro_rules ! impl_call_fill {
|
||||
)*
|
||||
};
|
||||
}
|
||||
impl_call_fill!(u32, i64, f16, bf16, f32);
|
||||
impl_call_fill!(u8, u32, i64, f16, bf16, f32);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
206
candle-metal-kernels/src/random.metal
Normal file
206
candle-metal-kernels/src/random.metal
Normal file
@ -0,0 +1,206 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_integer>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Constants
|
||||
// 2^32 and 1/2^32. Useful for converting between float and uint.
|
||||
static constexpr constant ulong UNIF01_NORM32 = 4294967296;
|
||||
static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
|
||||
// 2 * pi
|
||||
static constexpr constant float TWO_PI = 2.0 * M_PI_F;
|
||||
static constexpr constant int3 S1 = {13, 19, 12};
|
||||
static constexpr constant int3 S2 = {2, 25, 4};
|
||||
static constexpr constant int3 S3 = {3, 11, 17};
|
||||
|
||||
// Used to prevent bad seeds.
|
||||
static constexpr constant uint64_t PHI[16] = {
|
||||
0x9E3779B97F4A7C15,
|
||||
0xF39CC0605CEDC834,
|
||||
0x1082276BF3A27251,
|
||||
0xF86C6A11D0C18E95,
|
||||
0x2767F0B153D27B7F,
|
||||
0x0347045B5BF1827F,
|
||||
0x01886F0928403002,
|
||||
0xC1D64BA40F335E36,
|
||||
0xF06AD7AE9717877E,
|
||||
0x85839D6EFFBD7DC6,
|
||||
0x64D325D1C5371682,
|
||||
0xCADD0CCCFDFFBBE1,
|
||||
0x626E33B8D04B4331,
|
||||
0xBBF73C790D94F79D,
|
||||
0x471C4AB3ED3D82A5,
|
||||
0xFEC507705E4AE6E5,
|
||||
};
|
||||
|
||||
// Combined Tausworthe and LCG Random Number Generator.
|
||||
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
|
||||
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
|
||||
struct HybridTaus {
|
||||
|
||||
float state;
|
||||
|
||||
HybridTaus() thread = default;
|
||||
HybridTaus() threadgroup = default;
|
||||
HybridTaus() device = default;
|
||||
HybridTaus() constant = default;
|
||||
|
||||
// Generate seeds for each thread.
|
||||
METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
|
||||
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
|
||||
}
|
||||
|
||||
// Tausworthe generator.
|
||||
METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
|
||||
uint b = (((z << s.x) ^ z) >> s.y);
|
||||
return (((z & M) << s.z) ^ b);
|
||||
}
|
||||
|
||||
// LCG generator.
|
||||
METAL_FUNC static uint lcg(const uint z) {
|
||||
return (1664525 * z + 1013904223UL);
|
||||
}
|
||||
|
||||
// Initialize the RNG state.
|
||||
METAL_FUNC static HybridTaus init(const ulong4 seeds) {
|
||||
uint4 seed = seed_per_thread(seeds);
|
||||
|
||||
// Seed #1
|
||||
uint z1 = taus(seed.x, S1, 4294967294UL);
|
||||
uint z2 = taus(seed.y, S2, 4294967288UL);
|
||||
uint z3 = taus(seed.z, S3, 4294967280UL);
|
||||
uint z4 = lcg(seed.x);
|
||||
|
||||
// Seed #2
|
||||
uint r1 = (z1^z2^z3^z4^seed.y);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
// Seed #3
|
||||
r1 = (z1^z2^z3^z4^seed.z);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
// Seed #4
|
||||
r1 = (z1^z2^z3^z4^seed.w);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
HybridTaus rng;
|
||||
rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
|
||||
return rng;
|
||||
}
|
||||
|
||||
METAL_FUNC float rand() {
|
||||
uint seed = this->state * UNIF01_NORM32;
|
||||
uint z1 = taus(seed, S1, 429496729UL);
|
||||
uint z2 = taus(seed, S2, 4294967288UL);
|
||||
uint z3 = taus(seed, S3, 429496280UL);
|
||||
uint z4 = lcg(seed);
|
||||
|
||||
thread float result = this->state;
|
||||
this->state = (z1^z2^z3^z4) * UNIF01_INV32;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T> METAL_FUNC void rand_uniform(
|
||||
constant size_t &size,
|
||||
constant float &min,
|
||||
constant float &max,
|
||||
device atomic_uint *seed,
|
||||
device T *out,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
float diff = abs(min - max);
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
// Return early if tid == 0, otherwise we will write to out[size].
|
||||
return;
|
||||
}
|
||||
// Use symmetry to fill the other half of the array.
|
||||
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
}
|
||||
|
||||
// Create Gaussian normal distribution using Box-Muller transform:
|
||||
// https://en.wikipedia.org/wiki/Box–Muller_transform
|
||||
template<typename T> METAL_FUNC void normal(
|
||||
constant size_t &size,
|
||||
constant float &mean,
|
||||
constant float &stddev,
|
||||
device atomic_uint *seed,
|
||||
device T *out,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
float u1 = rng.rand();
|
||||
float u2 = rng.rand();
|
||||
|
||||
float cosval;
|
||||
float sinval = sincos(TWO_PI * u2, cosval);
|
||||
float mag = stddev * sqrt(-2.0 * log(u1));
|
||||
float z0 = mag * cosval + mean;
|
||||
float z1 = mag * sinval + mean;
|
||||
|
||||
out[tid] = static_cast<T>(z0);
|
||||
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
// Return early if tid == 0, otherwise we will write to out[size].
|
||||
return;
|
||||
}
|
||||
// Use symmetry to fill the other half of the array.
|
||||
out[size - tid] = static_cast<T>(z1);
|
||||
}
|
||||
|
||||
#define UNIFORM_OP(NAME, T) \
|
||||
kernel void rand_uniform_##NAME( \
|
||||
constant size_t &size, \
|
||||
constant float &min, \
|
||||
constant float &max, \
|
||||
device atomic_uint *seed, \
|
||||
device T *out, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
rand_uniform<T>(size, min, max, seed, out, tid); \
|
||||
} \
|
||||
|
||||
#define NORMAL_OP(NAME, T) \
|
||||
kernel void rand_normal_##NAME( \
|
||||
constant size_t &size, \
|
||||
constant float &mean, \
|
||||
constant float &stddev, \
|
||||
device atomic_uint *seed, \
|
||||
device T *out, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
normal<T>(size, mean, stddev, seed, out, tid); \
|
||||
} \
|
||||
|
||||
|
||||
#define RANDOM_OPS(NAME, T) \
|
||||
UNIFORM_OP(NAME, T) \
|
||||
NORMAL_OP(NAME, T) \
|
||||
|
||||
RANDOM_OPS(f32, float)
|
||||
RANDOM_OPS(f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
RANDOM_OPS(bf16, bfloat)
|
||||
#endif
|
@ -927,17 +927,13 @@ fn gemm() {
|
||||
);
|
||||
}
|
||||
|
||||
fn run_fill<T: EncoderParam + Clone>(elem_count: usize, value: T) -> Vec<T>
|
||||
where
|
||||
Unary<T>: FillOp<T>,
|
||||
{
|
||||
fn run_fill<T: FillOp + Clone>(elem_count: usize, value: T) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let buffer = new_buffer(&device, &vec![0.0f32; elem_count]);
|
||||
Unary::<T>::fill(
|
||||
call_fill(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
@ -954,10 +950,7 @@ where
|
||||
|
||||
#[test]
|
||||
fn fill() {
|
||||
fn assert_fill<T: EncoderParam + Copy + std::fmt::Debug + PartialEq>(value: T)
|
||||
where
|
||||
Unary<T>: FillOp<T>,
|
||||
{
|
||||
fn assert_fill<T: FillOp + Copy + std::fmt::Debug + PartialEq>(value: T) {
|
||||
for i in 0..4 {
|
||||
assert_eq!(run_fill(8 ^ i, value), vec![value; 8 ^ i]);
|
||||
}
|
||||
@ -969,3 +962,124 @@ fn fill() {
|
||||
assert_fill(bf16::from_f32(4.56));
|
||||
assert_fill(7.89f32);
|
||||
}
|
||||
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
|
||||
|
||||
let seed = device.new_buffer_with_data(
|
||||
&seed as *const u32 as *const core::ffi::c_void,
|
||||
std::mem::size_of::<u32>() as NSUInteger,
|
||||
options,
|
||||
);
|
||||
|
||||
if name.starts_with("rand_uniform") {
|
||||
call_random_uniform(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
a,
|
||||
b,
|
||||
length,
|
||||
&seed,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
call_random_normal(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
a,
|
||||
b,
|
||||
length,
|
||||
&seed,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn random() {
|
||||
fn calc_mean(data: &[f32]) -> f32 {
|
||||
let sum = data.iter().sum::<f32>() as f32;
|
||||
let count = data.len();
|
||||
assert!(count > 0);
|
||||
sum / count as f32
|
||||
}
|
||||
|
||||
fn calc_stddev(data: &[f32]) -> f32 {
|
||||
let mean = calc_mean(data);
|
||||
let count = data.len();
|
||||
assert!(count > 0);
|
||||
|
||||
let variance = data
|
||||
.iter()
|
||||
.map(|value| {
|
||||
let diff = mean - (*value as f32);
|
||||
diff * diff
|
||||
})
|
||||
.sum::<f32>()
|
||||
/ count as f32;
|
||||
|
||||
variance.sqrt()
|
||||
}
|
||||
|
||||
let shape = vec![1024, 10];
|
||||
|
||||
let length = shape.iter().product::<usize>();
|
||||
let seed = 299792458;
|
||||
|
||||
let min = -30.0;
|
||||
let max = 30.0;
|
||||
let mean = 100.0;
|
||||
let stddev = 50.0;
|
||||
|
||||
macro_rules! validate_random {
|
||||
($type:ty) => {
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
concat!("rand_uniform_", stringify!($type)),
|
||||
seed,
|
||||
length,
|
||||
min,
|
||||
max,
|
||||
)
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
results.iter().for_each(|v| {
|
||||
assert!(*v >= min && *v <= max);
|
||||
});
|
||||
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
|
||||
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
concat!("rand_normal_", stringify!($type)),
|
||||
seed,
|
||||
length,
|
||||
mean,
|
||||
stddev,
|
||||
)
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
|
||||
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
|
||||
};
|
||||
}
|
||||
|
||||
validate_random!(f32);
|
||||
validate_random!(f16);
|
||||
validate_random!(bf16);
|
||||
}
|
||||
|
Reference in New Issue
Block a user