mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Seed should be updated by random kernel result.
This commit is contained in:
@ -1587,10 +1587,10 @@ pub fn call_random_uniform(
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
seed: u64,
|
||||
min: f32,
|
||||
max: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
if min >= max {
|
||||
@ -1607,8 +1607,10 @@ pub fn call_random_uniform(
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, seed, min, max, buffer));
|
||||
set_params!(encoder, (length, min, max, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
@ -1623,10 +1625,10 @@ pub fn call_random_normal(
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
seed: u64,
|
||||
mean: f32,
|
||||
stddev: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
@ -1638,8 +1640,10 @@ pub fn call_random_normal(
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, seed, mean, stddev, buffer));
|
||||
set_params!(encoder, (length, mean, stddev, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
|
Reference in New Issue
Block a user