mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Gaussian normal distribution of PRNG via Box-Muller transform
This commit is contained in:
@ -1415,7 +1415,6 @@ pub fn call_gemm(
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
// println!("grid size {grid_size:?} group size {group_size:?}");
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
@ -1588,39 +1587,11 @@ pub fn call_random_uniform(
|
||||
"min must be less than max".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let size: usize = match name {
|
||||
"rand_uniform_f32" => 4,
|
||||
"rand_uniform_f16" | "rand_uniform_bf16" => 2,
|
||||
_ => Err(MetalKernelError::LoadLibraryError(format!(
|
||||
"{name} is not a valid kernel for random"
|
||||
)))?,
|
||||
};
|
||||
|
||||
let elems_per_key = length;
|
||||
let bytes_per_key = size * elems_per_key;
|
||||
|
||||
let out_per_key = (bytes_per_key + 4 - 1) / 4;
|
||||
let half_size = out_per_key / 2;
|
||||
let odd = length % 2 != 0;
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: length as u64,
|
||||
height: half_size as u64 + odd as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let threads = std::cmp::min(
|
||||
(half_size + odd as usize) as NSUInteger,
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
);
|
||||
let thread_group_size = MTLSize {
|
||||
width: threads,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -1635,5 +1606,36 @@ pub fn call_random_uniform(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_random_normal(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
seed: u64,
|
||||
mean: f32,
|
||||
stddev: f32,
|
||||
length: usize,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, seed, mean, stddev, buffer));
|
||||
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
Reference in New Issue
Block a user