Seed should be updated by random kernel result.

This commit is contained in:
Ivar Flakstad
2024-01-14 18:10:54 +01:00
parent ecf88a6d38
commit 79478ff5a1
4 changed files with 76 additions and 27 deletions

View File

@ -1587,10 +1587,10 @@ pub fn call_random_uniform(
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
seed: u64,
min: f32,
max: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
if min >= max {
@ -1607,8 +1607,10 @@ pub fn call_random_uniform(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, seed, min, max, buffer));
set_params!(encoder, (length, min, max, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
@ -1623,10 +1625,10 @@ pub fn call_random_normal(
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
seed: u64,
mean: f32,
stddev: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
@ -1638,8 +1640,10 @@ pub fn call_random_normal(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, seed, mean, stddev, buffer));
set_params!(encoder, (length, mean, stddev, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);