mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix use resource.
This commit is contained in:
@ -312,6 +312,8 @@ pub fn call_unary_contiguous(
|
|||||||
set_params!(encoder, (length, input, output));
|
set_params!(encoder, (length, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -354,6 +356,8 @@ pub fn call_unary_strided(
|
|||||||
let width: usize = shape.iter().product();
|
let width: usize = shape.iter().product();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -381,6 +385,9 @@ pub fn call_binary_contiguous(
|
|||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
|
encoder.use_resource(left, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -428,6 +435,9 @@ pub fn call_binary_strided(
|
|||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||||
|
|
||||||
|
encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -454,6 +464,8 @@ pub fn call_cast_contiguous(
|
|||||||
set_params!(encoder, (length, (input, input_offset), output));
|
set_params!(encoder, (length, (input, input_offset), output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -494,6 +506,8 @@ pub fn call_cast_strided(
|
|||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -541,6 +555,8 @@ pub fn call_reduce_contiguous(
|
|||||||
depth: 1,
|
depth: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -585,6 +601,8 @@ pub fn call_last_softmax(
|
|||||||
depth: 1,
|
depth: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -612,6 +630,8 @@ pub fn call_affine(
|
|||||||
set_params!(encoder, (size, mul, add, input, output));
|
set_params!(encoder, (size, mul, add, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -654,6 +674,8 @@ pub fn call_affine_strided(
|
|||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -680,6 +702,8 @@ pub fn call_powf(
|
|||||||
set_params!(encoder, (size, mul, input, output));
|
set_params!(encoder, (size, mul, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -720,6 +744,8 @@ pub fn call_powf_strided(
|
|||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -746,6 +772,8 @@ pub fn call_elu(
|
|||||||
set_params!(encoder, (size, mul, input, output));
|
set_params!(encoder, (size, mul, input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -786,6 +814,8 @@ pub fn call_elu_strided(
|
|||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -833,6 +863,10 @@ pub fn call_where_cond_strided(
|
|||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
|
|
||||||
|
encoder.use_resource(cond, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(left, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -880,6 +914,9 @@ pub fn call_index_select(
|
|||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -1121,6 +1158,9 @@ pub fn call_gemm(
|
|||||||
depth: 1,
|
depth: 1,
|
||||||
};
|
};
|
||||||
// println!("grid size {grid_size:?} group size {group_size:?}");
|
// println!("grid size {grid_size:?} group size {group_size:?}");
|
||||||
|
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
encoder.update_fence(&kernels.fence);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
|
Reference in New Issue
Block a user