Gaussian normal distribution of PRNG via Box-Muller transform

This commit is contained in:
Ivar Flakstad
2024-01-05 21:18:12 +01:00
parent 955e63c803
commit 6bf52b9fdf
5 changed files with 238 additions and 101 deletions

View File

@ -1415,7 +1415,6 @@ pub fn call_gemm(
height: 1,
depth: 1,
};
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
@ -1588,39 +1587,11 @@ pub fn call_random_uniform(
"min must be less than max".to_string(),
));
}
let size: usize = match name {
"rand_uniform_f32" => 4,
"rand_uniform_f16" | "rand_uniform_bf16" => 2,
_ => Err(MetalKernelError::LoadLibraryError(format!(
"{name} is not a valid kernel for random"
)))?,
};
let elems_per_key = length;
let bytes_per_key = size * elems_per_key;
let out_per_key = (bytes_per_key + 4 - 1) / 4;
let half_size = out_per_key / 2;
let odd = length % 2 != 0;
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let thread_group_count = MTLSize {
width: length as u64,
height: half_size as u64 + odd as u64,
depth: 1,
};
let threads = std::cmp::min(
(half_size + odd as usize) as NSUInteger,
pipeline.max_total_threads_per_threadgroup(),
);
let thread_group_size = MTLSize {
width: threads,
height: 1,
depth: 1,
};
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
@ -1635,5 +1606,36 @@ pub fn call_random_uniform(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
seed: u64,
mean: f32,
stddev: f32,
length: usize,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, seed, mean, stddev, buffer));
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests;