diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index fd47b1b8..945f9dbb 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -2,8 +2,9 @@ mod benchmarks; use criterion::criterion_main; criterion_main!( - benchmarks::matmul::benches, benchmarks::affine::benches, benchmarks::fill::benches, - benchmarks::where_cond::benches, + benchmarks::matmul::benches, + benchmarks::random::benches, + benchmarks::where_cond::benches ); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index eb2397ed..a02186e2 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod affine; pub(crate) mod fill; pub(crate) mod matmul; +pub(crate) mod random; pub(crate) mod where_cond; use candle_core::{Device, Result}; diff --git a/candle-core/benches/benchmarks/random.rs b/candle-core/benches/benchmarks/random.rs new file mode 100644 index 00000000..22c60ef1 --- /dev/null +++ b/candle-core/benches/benchmarks/random.rs @@ -0,0 +1,63 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn rand_uniform(a: &Tensor) { + a.rand_like(-1.0, 123.0).unwrap(); +} + +fn rand_normal(a: &Tensor) { + a.randn_like(100.0, 15.0).unwrap(); +} + +fn run_random_bench(c: &mut Criterion, device: &Device) { + let b = 1; + + let rows = 2048; + let cols = 2048; + + let dtype = DType::F32; + let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap(); + + let flops = b * rows * cols * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name("random_uniform")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |benches| { + benches.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + rand_uniform(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); + + let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap(); + + let mut group = c.benchmark_group(device.bench_name("random_normal")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |benches| { + benches.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + rand_normal(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_random_bench(c, &device); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 3358c3ea..32cce3d1 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -8,8 +8,9 @@ use half::{bf16, f16}; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; +use std::ffi::c_void; use std::path::Path; -use std::sync::{Arc, RwLock, TryLockError}; +use std::sync::{Arc, Mutex, RwLock, TryLockError}; /// Simple way to catch lock error without /// depending on T @@ -102,6 +103,8 @@ pub struct MetalDevice { /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers /// (strong_count = 1). buffers: AllocatedBuffers, + /// Seed for random number generation. + seed: Arc>, } impl std::fmt::Debug for MetalDevice { @@ -226,7 +229,7 @@ impl MetalDevice { // The slice might not live long enough for metal // To actually fill the GPU buffer. // Putting this wait forces the GPU buffer to be filled - // with the actual data allowing the CPU storage todo + // with the actual data allowing the CPU storage to do // deallocate properly. self.wait_until_completed()?; Ok(real) @@ -1555,6 +1558,11 @@ impl BackendDevice for MetalDevice { Ok(val) => val.parse()?, _ => 10, }; + let seed = Arc::new(Mutex::new(device.new_buffer_with_data( + [299792458].as_ptr() as *const c_void, + 4, + MTLResourceOptions::StorageModeManaged, + ))); Ok(Self { device, command_queue, @@ -1563,13 +1571,10 @@ impl BackendDevice for MetalDevice { compute_per_buffer, buffers, kernels, + seed, }) } - fn set_seed(&self, _seed: u64) -> Result<()> { - crate::bail!("Metal set_seed not implemented") - } - fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Metal { gpu_id: self.registry_id() as usize, @@ -1618,7 +1623,7 @@ impl BackendDevice for MetalDevice { DType::F16 => fill!(f16::ONE), DType::F32 => fill!(1f32), DType::F64 => { - return Err(MetalError::Message(format!("metal doesn't support double")).into()) + return Err(MetalError::Message("Metal doesn't support double".to_string()).into()) } } Ok(MetalStorage::new(buffer, self.clone(), dtype)) @@ -1641,12 +1646,31 @@ impl BackendDevice for MetalDevice { &self, shape: &Shape, dtype: DType, - mean: f64, - stddev: f64, + min: f64, + max: f64, ) -> Result { - // TODO is there a better way ? - let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?; - self.storage_from_cpu_storage(&cpu_storage) + let name = match dtype { + DType::F32 => "rand_uniform_f32", + DType::F16 => "rand_uniform_f16", + DType::BF16 => "rand_uniform_bf16", + dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"), + }; + let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?; + let command_buffer = self.command_buffer()?; + candle_metal_kernels::call_random_uniform( + &self.device, + &command_buffer, + &self.kernels, + name, + min as f32, + max as f32, + shape.elem_count(), + &*self.seed.lock().unwrap(), + &buffer, + ) + .map_err(MetalError::from)?; + + Ok(Self::Storage::new(buffer, self.clone(), dtype)) } fn rand_normal( @@ -1656,9 +1680,43 @@ impl BackendDevice for MetalDevice { mean: f64, stddev: f64, ) -> Result { - // TODO is there a better way ? - let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?; - self.storage_from_cpu_storage(&cpu_storage) + let name = match dtype { + DType::F32 => "rand_normal_f32", + DType::F16 => "rand_normal_f16", + DType::BF16 => "rand_normal_bf16", + dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"), + }; + let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?; + let command_buffer = self.command_buffer()?; + candle_metal_kernels::call_random_normal( + &self.device, + &command_buffer, + &self.kernels, + name, + mean as f32, + stddev as f32, + shape.elem_count(), + &*self.seed.lock().unwrap(), + &buffer, + ) + .map_err(MetalError::from)?; + + Ok(Self::Storage::new(buffer, self.clone(), dtype)) + } + + fn set_seed(&self, seed: u64) -> Result<()> { + let seed: u32 = seed.try_into().map_err(|_| { + MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string()) + })?; + + let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?; + let contents = seed_buffer.contents(); + unsafe { + std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4); + } + seed_buffer.did_modify_range(metal::NSRange::new(0, 4)); + + Ok(()) } } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 250a7d05..c06f168f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -14,8 +14,9 @@ const BINARY: &str = include_str!("binary.metal"); const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const FILL: &str = include_str!("fill.metal"); -const REDUCE: &str = include_str!("reduce.metal"); const CONV: &str = include_str!("conv.metal"); +const REDUCE: &str = include_str!("reduce.metal"); +const RANDOM: &str = include_str!("random.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); @@ -48,9 +49,10 @@ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, /// Helper functions to create the various objects on the compute command encoder /// on a single line. /// Prevents getting wrong some arguments number and mixing length and size in bytes. -pub trait EncoderParam { +pub trait EncoderParam: private::Sealed { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } + macro_rules! primitive { ($type:ty) => { impl EncoderParam for $type { @@ -64,14 +66,14 @@ macro_rules! primitive { } }; } -primitive!(usize); -primitive!(u8); -primitive!(u32); -primitive!(i32); -primitive!(i64); -primitive!(f16); -primitive!(bf16); -primitive!(f32); +macro_rules! primitives { + ($($type:ty),+) => { + $( + primitive!($type); + )+ + }; +} +primitives!(bool, usize, u8, u32, u64, i32, i64, f16, bf16, f32); impl EncoderParam for &[T] { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { @@ -114,6 +116,38 @@ macro_rules! set_params { ); } +// Seal the trait so that only the types we want can implement it +mod private { + use super::*; + + pub trait Sealed {} + + macro_rules! sealed { + ($($type:ty),+) => { + $( + impl Sealed for $type {} + )+ + }; + } + sealed!( + usize, + u8, + u32, + u64, + i32, + i64, + f16, + bf16, + f32, + bool, + &Buffer, + (&Buffer, usize), + &mut Buffer, + (&mut Buffer, usize) + ); + impl Sealed for &[T] {} +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { Affine, @@ -126,6 +160,7 @@ pub enum Source { Mfa, Conv, Fill, + Random, Quantized, } @@ -250,6 +285,7 @@ impl Kernels { Source::Reduce => REDUCE, Source::Fill => FILL, Source::Conv => CONV, + Source::Random => RANDOM, Source::Quantized => QUANTIZED, Source::Mfa => panic!("Invalid lib"), } @@ -1536,6 +1572,73 @@ pub fn call_upsample_nearest_2d( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_random_uniform( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + min: f32, + max: f32, + length: usize, + seed: &Buffer, + buffer: &Buffer, +) -> Result<(), MetalKernelError> { + if min >= max { + return Err(MetalKernelError::LoadLibraryError( + "min must be less than max".to_string(), + )); + } + let pipeline = kernels.load_pipeline(device, Source::Random, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + + let odd = (length % 2 != 0) as usize; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, min, max, seed, buffer)); + + encoder.use_resource(seed, metal::MTLResourceUsage::Read); + encoder.use_resource(seed, metal::MTLResourceUsage::Write); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_random_normal( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + mean: f32, + stddev: f32, + length: usize, + seed: &Buffer, + buffer: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Random, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + + let odd = (length % 2 != 0) as usize; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, mean, stddev, seed, buffer)); + + encoder.use_resource(seed, metal::MTLResourceUsage::Read); + encoder.use_resource(seed, metal::MTLResourceUsage::Write); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + Ok(()) +} + #[derive(Debug, Clone, Copy)] pub enum GgmlDType { Q4_0, @@ -1767,7 +1870,7 @@ macro_rules ! impl_call_fill { )* }; } -impl_call_fill!(u32, i64, f16, bf16, f32); +impl_call_fill!(u8, u32, i64, f16, bf16, f32); #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal new file mode 100644 index 00000000..a7e48393 --- /dev/null +++ b/candle-metal-kernels/src/random.metal @@ -0,0 +1,206 @@ +#include +#include +#include + +using namespace metal; + +// Constants +// 2^32 and 1/2^32. Useful for converting between float and uint. +static constexpr constant ulong UNIF01_NORM32 = 4294967296; +static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10; +// 2 * pi +static constexpr constant float TWO_PI = 2.0 * M_PI_F; +static constexpr constant int3 S1 = {13, 19, 12}; +static constexpr constant int3 S2 = {2, 25, 4}; +static constexpr constant int3 S3 = {3, 11, 17}; + +// Used to prevent bad seeds. +static constexpr constant uint64_t PHI[16] = { + 0x9E3779B97F4A7C15, + 0xF39CC0605CEDC834, + 0x1082276BF3A27251, + 0xF86C6A11D0C18E95, + 0x2767F0B153D27B7F, + 0x0347045B5BF1827F, + 0x01886F0928403002, + 0xC1D64BA40F335E36, + 0xF06AD7AE9717877E, + 0x85839D6EFFBD7DC6, + 0x64D325D1C5371682, + 0xCADD0CCCFDFFBBE1, + 0x626E33B8D04B4331, + 0xBBF73C790D94F79D, + 0x471C4AB3ED3D82A5, + 0xFEC507705E4AE6E5, +}; + +// Combined Tausworthe and LCG Random Number Generator. +// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application +// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf +struct HybridTaus { + + float state; + + HybridTaus() thread = default; + HybridTaus() threadgroup = default; + HybridTaus() device = default; + HybridTaus() constant = default; + + // Generate seeds for each thread. + METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) { + return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL)); + } + + // Tausworthe generator. + METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) { + uint b = (((z << s.x) ^ z) >> s.y); + return (((z & M) << s.z) ^ b); + } + + // LCG generator. + METAL_FUNC static uint lcg(const uint z) { + return (1664525 * z + 1013904223UL); + } + + // Initialize the RNG state. + METAL_FUNC static HybridTaus init(const ulong4 seeds) { + uint4 seed = seed_per_thread(seeds); + + // Seed #1 + uint z1 = taus(seed.x, S1, 4294967294UL); + uint z2 = taus(seed.y, S2, 4294967288UL); + uint z3 = taus(seed.z, S3, 4294967280UL); + uint z4 = lcg(seed.x); + + // Seed #2 + uint r1 = (z1^z2^z3^z4^seed.y); + z1 = taus(r1, S1, 429496729UL); + z2 = taus(r1, S2, 4294967288UL); + z3 = taus(r1, S3, 429496280UL); + z4 = lcg(r1); + + // Seed #3 + r1 = (z1^z2^z3^z4^seed.z); + z1 = taus(r1, S1, 429496729UL); + z2 = taus(r1, S2, 4294967288UL); + z3 = taus(r1, S3, 429496280UL); + z4 = lcg(r1); + + // Seed #4 + r1 = (z1^z2^z3^z4^seed.w); + z1 = taus(r1, S1, 429496729UL); + z2 = taus(r1, S2, 4294967288UL); + z3 = taus(r1, S3, 429496280UL); + z4 = lcg(r1); + + HybridTaus rng; + rng.state = (z1^z2^z3^z4) * UNIF01_INV32; + return rng; + } + + METAL_FUNC float rand() { + uint seed = this->state * UNIF01_NORM32; + uint z1 = taus(seed, S1, 429496729UL); + uint z2 = taus(seed, S2, 4294967288UL); + uint z3 = taus(seed, S3, 429496280UL); + uint z4 = lcg(seed); + + thread float result = this->state; + this->state = (z1^z2^z3^z4) * UNIF01_INV32; + return result; + } +}; + +template METAL_FUNC void rand_uniform( + constant size_t &size, + constant float &min, + constant float &max, + device atomic_uint *seed, + device T *out, + uint tid [[thread_position_in_grid]] +) { + if (tid >= size) { + return; + } + + float diff = abs(min - max); + HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); + out[tid] = static_cast(rng.rand() * diff + min); + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + // Return early if tid == 0, otherwise we will write to out[size]. + return; + } + // Use symmetry to fill the other half of the array. + out[size - tid] = static_cast(rng.rand() * diff + min); +} + +// Create Gaussian normal distribution using Box-Muller transform: +// https://en.wikipedia.org/wiki/Box–Muller_transform +template METAL_FUNC void normal( + constant size_t &size, + constant float &mean, + constant float &stddev, + device atomic_uint *seed, + device T *out, + uint tid [[thread_position_in_grid]] +) { + if (tid >= size) { + return; + } + HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); + float u1 = rng.rand(); + float u2 = rng.rand(); + + float cosval; + float sinval = sincos(TWO_PI * u2, cosval); + float mag = stddev * sqrt(-2.0 * log(u1)); + float z0 = mag * cosval + mean; + float z1 = mag * sinval + mean; + + out[tid] = static_cast(z0); + + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + // Return early if tid == 0, otherwise we will write to out[size]. + return; + } + // Use symmetry to fill the other half of the array. + out[size - tid] = static_cast(z1); +} + +#define UNIFORM_OP(NAME, T) \ +kernel void rand_uniform_##NAME( \ + constant size_t &size, \ + constant float &min, \ + constant float &max, \ + device atomic_uint *seed, \ + device T *out, \ + uint tid [[thread_position_in_grid]] \ +) { \ + rand_uniform(size, min, max, seed, out, tid); \ +} \ + +#define NORMAL_OP(NAME, T) \ +kernel void rand_normal_##NAME( \ + constant size_t &size, \ + constant float &mean, \ + constant float &stddev, \ + device atomic_uint *seed, \ + device T *out, \ + uint tid [[thread_position_in_grid]] \ +) { \ + normal(size, mean, stddev, seed, out, tid); \ +} \ + + +#define RANDOM_OPS(NAME, T) \ +UNIFORM_OP(NAME, T) \ +NORMAL_OP(NAME, T) \ + +RANDOM_OPS(f32, float) +RANDOM_OPS(f16, half) + +#if __METAL_VERSION__ >= 310 +RANDOM_OPS(bf16, bfloat) +#endif diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 79b2e85d..be6639ef 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -927,17 +927,13 @@ fn gemm() { ); } -fn run_fill(elem_count: usize, value: T) -> Vec -where - Unary: FillOp, -{ +fn run_fill(elem_count: usize, value: T) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let buffer = new_buffer(&device, &vec![0.0f32; elem_count]); - Unary::::fill( + call_fill( &device, command_buffer, &kernels, @@ -954,10 +950,7 @@ where #[test] fn fill() { - fn assert_fill(value: T) - where - Unary: FillOp, - { + fn assert_fill(value: T) { for i in 0..4 { assert_eq!(run_fill(8 ^ i, value), vec![value; 8 ^ i]); } @@ -969,3 +962,124 @@ fn fill() { assert_fill(bf16::from_f32(4.56)); assert_fill(7.89f32); } + +fn run_random(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let options = MTLResourceOptions::StorageModeManaged; + let output = device.new_buffer((length * core::mem::size_of::()) as NSUInteger, options); + + let seed = device.new_buffer_with_data( + &seed as *const u32 as *const core::ffi::c_void, + std::mem::size_of::() as NSUInteger, + options, + ); + + if name.starts_with("rand_uniform") { + call_random_uniform( + &device, + command_buffer, + &kernels, + name, + a, + b, + length, + &seed, + &output, + ) + .unwrap(); + } else { + call_random_normal( + &device, + command_buffer, + &kernels, + name, + a, + b, + length, + &seed, + &output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +#[test] +fn random() { + fn calc_mean(data: &[f32]) -> f32 { + let sum = data.iter().sum::() as f32; + let count = data.len(); + assert!(count > 0); + sum / count as f32 + } + + fn calc_stddev(data: &[f32]) -> f32 { + let mean = calc_mean(data); + let count = data.len(); + assert!(count > 0); + + let variance = data + .iter() + .map(|value| { + let diff = mean - (*value as f32); + diff * diff + }) + .sum::() + / count as f32; + + variance.sqrt() + } + + let shape = vec![1024, 10]; + + let length = shape.iter().product::(); + let seed = 299792458; + + let min = -30.0; + let max = 30.0; + let mean = 100.0; + let stddev = 50.0; + + macro_rules! validate_random { + ($type:ty) => { + let results: Vec = run_random::<$type>( + concat!("rand_uniform_", stringify!($type)), + seed, + length, + min, + max, + ) + .into_iter() + .map(f32::from) + .collect(); + results.iter().for_each(|v| { + assert!(*v >= min && *v <= max); + }); + assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); + + let results: Vec = run_random::<$type>( + concat!("rand_normal_", stringify!($type)), + seed, + length, + mean, + stddev, + ) + .into_iter() + .map(f32::from) + .collect(); + assert!((calc_mean(&results) - mean).abs() < mean / 10.0); + assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0); + }; + } + + validate_random!(f32); + validate_random!(f16); + validate_random!(bf16); +}