From f419a38e1ad431cac245e0d7525b2c278660df18 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 14 Dec 2023 16:52:37 +0100 Subject: [PATCH] Fix use resource. --- candle-metal-kernels/src/lib.rs | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 01432ccb..0c383dec 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -312,6 +312,8 @@ pub fn call_unary_contiguous( set_params!(encoder, (length, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -354,6 +356,8 @@ pub fn call_unary_strided( let width: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -381,6 +385,9 @@ pub fn call_binary_contiguous( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(left, metal::MTLResourceUsage::Read); + encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -428,6 +435,9 @@ pub fn call_binary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.use_resource(left_input, metal::MTLResourceUsage::Read); + encoder.use_resource(right_input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -454,6 +464,8 @@ pub fn call_cast_contiguous( set_params!(encoder, (length, (input, input_offset), output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -494,6 +506,8 @@ pub fn call_cast_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -541,6 +555,8 @@ pub fn call_reduce_contiguous( depth: 1, }; + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -585,6 +601,8 @@ pub fn call_last_softmax( depth: 1, }; + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -612,6 +630,8 @@ pub fn call_affine( set_params!(encoder, (size, mul, add, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -654,6 +674,8 @@ pub fn call_affine_strided( ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -680,6 +702,8 @@ pub fn call_powf( set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -720,6 +744,8 @@ pub fn call_powf_strided( ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -746,6 +772,8 @@ pub fn call_elu( set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -786,6 +814,8 @@ pub fn call_elu_strided( ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -833,6 +863,10 @@ pub fn call_where_cond_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(cond, metal::MTLResourceUsage::Read); + encoder.use_resource(left, metal::MTLResourceUsage::Read); + encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -880,6 +914,9 @@ pub fn call_index_select( let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding(); @@ -1121,6 +1158,9 @@ pub fn call_gemm( depth: 1, }; // println!("grid size {grid_size:?} group size {group_size:?}"); + encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_size, group_size); encoder.update_fence(&kernels.fence); encoder.end_encoding();