Fix `use_resource`: mark input buffers as `Read` and output buffers as `Write` before each kernel dispatch.

This commit is contained in:
Nicolas Patry
2023-12-14 16:52:37 +01:00
parent 361f2ad2af
commit f419a38e1a

View File

@ -312,6 +312,8 @@ pub fn call_unary_contiguous(
set_params!(encoder, (length, input, output)); set_params!(encoder, (length, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -354,6 +356,8 @@ pub fn call_unary_strided(
let width: usize = shape.iter().product(); let width: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -381,6 +385,9 @@ pub fn call_binary_contiguous(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(left, metal::MTLResourceUsage::Read);
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -428,6 +435,9 @@ pub fn call_binary_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -454,6 +464,8 @@ pub fn call_cast_contiguous(
set_params!(encoder, (length, (input, input_offset), output)); set_params!(encoder, (length, (input, input_offset), output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -494,6 +506,8 @@ pub fn call_cast_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -541,6 +555,8 @@ pub fn call_reduce_contiguous(
depth: 1, depth: 1,
}; };
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -585,6 +601,8 @@ pub fn call_last_softmax(
depth: 1, depth: 1,
}; };
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -612,6 +630,8 @@ pub fn call_affine(
set_params!(encoder, (size, mul, add, input, output)); set_params!(encoder, (size, mul, add, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -654,6 +674,8 @@ pub fn call_affine_strided(
); );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -680,6 +702,8 @@ pub fn call_powf(
set_params!(encoder, (size, mul, input, output)); set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -720,6 +744,8 @@ pub fn call_powf_strided(
); );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -746,6 +772,8 @@ pub fn call_elu(
set_params!(encoder, (size, mul, input, output)); set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -786,6 +814,8 @@ pub fn call_elu_strided(
); );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -833,6 +863,10 @@ pub fn call_where_cond_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.use_resource(cond, metal::MTLResourceUsage::Read);
encoder.use_resource(left, metal::MTLResourceUsage::Read);
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -880,6 +914,9 @@ pub fn call_index_select(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
@ -1121,6 +1158,9 @@ pub fn call_gemm(
depth: 1, depth: 1,
}; };
// println!("grid size {grid_size:?} group size {group_size:?}"); // println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size); encoder.dispatch_thread_groups(grid_size, group_size);
encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();