mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Use RAII for terminating the encoding. (#2353)
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use metal::{
|
||||
Buffer, CompileOptions, ComputePipelineState, Device, Function, FunctionConstantValues,
|
||||
Library, MTLDataType, MTLSize, NSUInteger,
|
||||
Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
|
||||
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
@ -311,6 +311,7 @@ pub fn call_copy2d(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -333,7 +334,6 @@ pub fn call_copy2d(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_threads(grid_dims, group_dims);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -349,6 +349,7 @@ pub fn call_unary_contiguous_tiled(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let tile_size = 2;
|
||||
let tiles = (length + tile_size - 1) / tile_size;
|
||||
|
||||
@ -360,7 +361,6 @@ pub fn call_unary_contiguous_tiled(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -376,6 +376,7 @@ pub fn call_unary_contiguous(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -385,7 +386,6 @@ pub fn call_unary_contiguous(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -405,6 +405,7 @@ pub fn call_unary_strided(
|
||||
let length: usize = shape.iter().product();
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -412,7 +413,6 @@ pub fn call_unary_strided(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -430,6 +430,7 @@ pub fn call_binary_contiguous(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, &left, &right, output));
|
||||
@ -440,7 +441,6 @@ pub fn call_binary_contiguous(
|
||||
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -461,6 +461,7 @@ pub fn call_binary_strided(
|
||||
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let width: usize = shape.iter().product();
|
||||
let length: usize = shape.iter().product();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
@ -483,7 +484,6 @@ pub fn call_binary_strided(
|
||||
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -501,6 +501,7 @@ pub fn call_cast_contiguous(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, &input, output));
|
||||
@ -509,7 +510,6 @@ pub fn call_cast_contiguous(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -527,6 +527,7 @@ pub fn call_cast_strided(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -541,7 +542,6 @@ pub fn call_cast_strided(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -560,6 +560,7 @@ pub fn call_reduce_contiguous(
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, elements_to_sum, &input, output));
|
||||
@ -585,7 +586,6 @@ pub fn call_reduce_contiguous(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -606,6 +606,7 @@ pub fn call_reduce_strided(
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -634,7 +635,6 @@ pub fn call_reduce_strided(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -652,6 +652,7 @@ pub fn call_last_softmax(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -682,7 +683,6 @@ pub fn call_last_softmax(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -703,6 +703,7 @@ pub fn call_rms_norm(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -741,7 +742,6 @@ pub fn call_rms_norm(
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -764,6 +764,7 @@ pub fn call_layer_norm(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -803,7 +804,6 @@ pub fn call_layer_norm(
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -825,6 +825,7 @@ pub fn call_rope_i(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -844,7 +845,6 @@ pub fn call_rope_i(
|
||||
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -868,6 +868,7 @@ pub fn call_rope_thd(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -889,7 +890,6 @@ pub fn call_rope_thd(
|
||||
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -912,6 +912,7 @@ pub fn call_rope(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -932,7 +933,6 @@ pub fn call_rope(
|
||||
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -951,6 +951,7 @@ pub fn call_affine(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, add, &input, output));
|
||||
@ -959,7 +960,6 @@ pub fn call_affine(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -980,6 +980,7 @@ pub fn call_affine_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1000,7 +1001,6 @@ pub fn call_affine_strided(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1018,6 +1018,7 @@ pub fn call_powf(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, &input, output));
|
||||
@ -1026,7 +1027,6 @@ pub fn call_powf(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1046,6 +1046,7 @@ pub fn call_powf_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1057,7 +1058,6 @@ pub fn call_powf_strided(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1075,6 +1075,7 @@ pub fn call_elu(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, &input, output));
|
||||
@ -1083,7 +1084,6 @@ pub fn call_elu(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1103,6 +1103,7 @@ pub fn call_elu_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1114,7 +1115,6 @@ pub fn call_elu_strided(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1136,6 +1136,7 @@ pub fn call_where_cond_strided(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let size: usize = shape.iter().product();
|
||||
@ -1164,7 +1165,6 @@ pub fn call_where_cond_strided(
|
||||
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1192,6 +1192,7 @@ pub fn call_index_select(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -1218,7 +1219,6 @@ pub fn call_index_select(
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1243,6 +1243,7 @@ pub fn call_gather(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -1266,7 +1267,6 @@ pub fn call_gather(
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1292,6 +1292,7 @@ pub fn call_scatter_add(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -1315,7 +1316,6 @@ pub fn call_scatter_add(
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1342,6 +1342,7 @@ pub fn call_index_add(
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -1366,7 +1367,6 @@ pub fn call_index_add(
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1573,6 +1573,7 @@ pub fn call_gemm(
|
||||
let block_bytes = block_elements * bytes;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
@ -1615,8 +1616,6 @@ pub fn call_gemm(
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1637,6 +1636,7 @@ pub fn call_im2col1d_strided(
|
||||
let dst_el = shape[0] * l_out * shape[1] * k_size;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
@ -1646,8 +1646,6 @@ pub fn call_im2col1d_strided(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1670,6 +1668,7 @@ pub fn call_col2im1d(
|
||||
let dst_el = shape[0] * c_out * l_out;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
@ -1679,8 +1678,6 @@ pub fn call_col2im1d(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1706,6 +1703,7 @@ pub fn call_im2col_strided(
|
||||
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
|
||||
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
@ -1718,8 +1716,6 @@ pub fn call_im2col_strided(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1742,6 +1738,7 @@ pub fn call_upsample_nearest_2d(
|
||||
let scale_h = shape[3] as f32 / out_h as f32;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1750,8 +1747,6 @@ pub fn call_upsample_nearest_2d(
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1774,6 +1769,7 @@ pub fn call_random_uniform(
|
||||
}
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
@ -1788,8 +1784,6 @@ pub fn call_random_uniform(
|
||||
);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1807,6 +1801,7 @@ pub fn call_random_normal(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
@ -1821,8 +1816,6 @@ pub fn call_random_normal(
|
||||
);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1962,6 +1955,7 @@ pub fn call_quantized_matmul_mv_t(
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1993,8 +1987,6 @@ pub fn call_quantized_matmul_mv_t(
|
||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -2023,6 +2015,7 @@ pub fn call_pool2d(
|
||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -2031,7 +2024,6 @@ pub fn call_pool2d(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -2062,6 +2054,7 @@ pub fn call_conv_transpose1d(
|
||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -2084,7 +2077,6 @@ pub fn call_conv_transpose1d(
|
||||
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -2120,6 +2112,7 @@ pub fn call_conv_transpose2d(
|
||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -2143,7 +2136,6 @@ pub fn call_conv_transpose2d(
|
||||
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -2161,6 +2153,7 @@ pub fn call_arg_sort(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
||||
@ -2180,7 +2173,6 @@ pub fn call_arg_sort(
|
||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
ep.maybe_end_encoding(encoder);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -162,24 +162,40 @@ macro_rules! set_params {
|
||||
}
|
||||
|
||||
pub trait EncoderProvider {
|
||||
fn encoder(&self) -> &ComputeCommandEncoderRef;
|
||||
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef);
|
||||
type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
|
||||
}
|
||||
|
||||
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
|
||||
|
||||
impl<'a> Drop for WrappedEncoder<'a> {
|
||||
fn drop(&mut self) {
|
||||
self.0.end_encoding()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
|
||||
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderProvider for &metal::CommandBuffer {
|
||||
fn encoder(&self) -> &ComputeCommandEncoderRef {
|
||||
self.new_compute_command_encoder()
|
||||
}
|
||||
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
|
||||
enc.end_encoding()
|
||||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
|
||||
WrappedEncoder(self.new_compute_command_encoder())
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderProvider for &metal::CommandBufferRef {
|
||||
fn encoder(&self) -> &ComputeCommandEncoderRef {
|
||||
self.new_compute_command_encoder()
|
||||
}
|
||||
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
|
||||
enc.end_encoding()
|
||||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
|
||||
WrappedEncoder(self.new_compute_command_encoder())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user