Seed should be updated by random kernel result.

This commit is contained in:
Ivar Flakstad
2024-01-14 18:10:54 +01:00
parent ecf88a6d38
commit 79478ff5a1
4 changed files with 76 additions and 27 deletions

View File

@ -938,14 +938,21 @@ fn gemm() {
);
}
fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
let seed = device.new_buffer_with_data(
&seed as *const u32 as *const core::ffi::c_void,
std::mem::size_of::<u32>() as NSUInteger,
options,
);
if name.starts_with("rand_uniform") {
call_random_uniform(
@ -953,10 +960,10 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
command_buffer,
&kernels,
name,
seed,
a,
b,
length,
&seed,
&output,
)
.unwrap();
@ -966,15 +973,14 @@ fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b:
command_buffer,
&kernels,
name,
seed,
a,
b,
length,
&seed,
&output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
@ -1029,7 +1035,9 @@ fn random() {
.into_iter()
.map(f32::from)
.collect();
results.iter().for_each(|v| assert!(*v >= min && *v <= max));
results.iter().for_each(|v| {
assert!(*v >= min && *v <= max);
});
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>(