Use a trait for the encoder provider (so that encoder can ultimately be reused). (#2352)

This commit is contained in:
Laurent Mazare
2024-07-24 08:27:30 +01:00
committed by GitHub
parent 6056fd5c90
commit a925ae6bc6
2 changed files with 143 additions and 120 deletions

View File

@ -1,6 +1,6 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
Buffer, CompileOptions, ComputePipelineState, Device, Function, FunctionConstantValues,
Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@ -8,7 +8,7 @@ use std::sync::RwLock;
mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split};
use utils::{get_block_dims, linear_split, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
@ -297,7 +297,7 @@ impl Kernels {
#[allow(clippy::too_many_arguments)]
pub fn call_copy2d(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: copy2d::Kernel,
input: &Buffer,
@ -310,7 +310,7 @@ pub fn call_copy2d(
dst_o_in_bytes: usize,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -333,14 +333,14 @@ pub fn call_copy2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous_tiled::Kernel,
length: usize,
@ -348,7 +348,7 @@ pub fn call_unary_contiguous_tiled(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
let tile_size = 2;
let tiles = (length + tile_size - 1) / tile_size;
@ -360,14 +360,14 @@ pub fn call_unary_contiguous_tiled(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
@ -375,7 +375,7 @@ pub fn call_unary_contiguous(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -385,14 +385,14 @@ pub fn call_unary_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
@ -404,7 +404,7 @@ pub fn call_unary_strided(
let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline);
@ -412,14 +412,14 @@ pub fn call_unary_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_binary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: binary::contiguous::Kernel,
length: usize,
@ -429,7 +429,7 @@ pub fn call_binary_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, &left, &right, output));
@ -440,14 +440,14 @@ pub fn call_binary_contiguous(
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_binary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: binary::strided::Kernel,
shape: &[usize],
@ -460,7 +460,7 @@ pub fn call_binary_strided(
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
let width: usize = shape.iter().product();
let length: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
@ -483,7 +483,7 @@ pub fn call_binary_strided(
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -491,7 +491,7 @@ pub fn call_binary_strided(
#[allow(clippy::too_many_arguments)]
pub fn call_cast_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@ -500,7 +500,7 @@ pub fn call_cast_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, &input, output));
@ -509,14 +509,14 @@ pub fn call_cast_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_cast_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
@ -526,7 +526,7 @@ pub fn call_cast_strided(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@ -541,14 +541,14 @@ pub fn call_cast_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@ -559,7 +559,7 @@ pub fn call_reduce_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, &input, output));
@ -585,14 +585,14 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
@ -605,7 +605,7 @@ pub fn call_reduce_strided(
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -634,14 +634,14 @@ pub fn call_reduce_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@ -651,7 +651,7 @@ pub fn call_last_softmax(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -682,14 +682,14 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rms_norm(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@ -702,7 +702,7 @@ pub fn call_rms_norm(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -741,14 +741,14 @@ pub fn call_rms_norm(
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_layer_norm(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@ -763,7 +763,7 @@ pub fn call_layer_norm(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -803,14 +803,14 @@ pub fn call_layer_norm(
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope_i(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
bh: usize,
@ -824,7 +824,7 @@ pub fn call_rope_i(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -844,14 +844,14 @@ pub fn call_rope_i(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope_thd(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
b: usize,
@ -867,7 +867,7 @@ pub fn call_rope_thd(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -889,14 +889,14 @@ pub fn call_rope_thd(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
bh: usize,
@ -911,7 +911,7 @@ pub fn call_rope(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -932,14 +932,14 @@ pub fn call_rope(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
size: usize,
@ -950,7 +950,7 @@ pub fn call_affine(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, &input, output));
@ -959,14 +959,14 @@ pub fn call_affine(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -979,7 +979,7 @@ pub fn call_affine_strided(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1000,14 +1000,14 @@ pub fn call_affine_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_powf(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
size: usize,
@ -1017,7 +1017,7 @@ pub fn call_powf(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, &input, output));
@ -1026,14 +1026,14 @@ pub fn call_powf(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_powf_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1045,7 +1045,7 @@ pub fn call_powf_strided(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1057,14 +1057,14 @@ pub fn call_powf_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_elu(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
size: usize,
@ -1074,7 +1074,7 @@ pub fn call_elu(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, &input, output));
@ -1083,14 +1083,14 @@ pub fn call_elu(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_elu_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1102,7 +1102,7 @@ pub fn call_elu_strided(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1114,14 +1114,14 @@ pub fn call_elu_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1135,7 +1135,7 @@ pub fn call_where_cond_strided(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@ -1164,14 +1164,14 @@ pub fn call_where_cond_strided(
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_index_select(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1191,7 +1191,7 @@ pub fn call_index_select(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -1218,14 +1218,14 @@ pub fn call_index_select(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_gather(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1242,7 +1242,7 @@ pub fn call_gather(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -1266,14 +1266,14 @@ pub fn call_gather(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
src_shape: &[usize],
@ -1291,7 +1291,7 @@ pub fn call_scatter_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -1315,14 +1315,14 @@ pub fn call_scatter_add(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
src_shape: &[usize],
@ -1341,7 +1341,7 @@ pub fn call_index_add(
let ids_dim_size = ids_shape[0];
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -1366,7 +1366,7 @@ pub fn call_index_add(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -1453,7 +1453,7 @@ impl ConstantValues {
#[allow(clippy::too_many_arguments)]
pub fn call_gemm(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
@ -1572,7 +1572,7 @@ pub fn call_gemm(
};
let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@ -1615,7 +1615,7 @@ pub fn call_gemm(
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -1623,7 +1623,7 @@ pub fn call_gemm(
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1636,7 +1636,7 @@ pub fn call_im2col1d_strided(
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
let dst_el = shape[0] * l_out * shape[1] * k_size;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1646,7 +1646,7 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -1654,7 +1654,7 @@ pub fn call_im2col1d_strided(
#[allow(clippy::too_many_arguments)]
pub fn call_col2im1d(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1669,7 +1669,7 @@ pub fn call_col2im1d(
let l_out = (l_in - 1) * stride + k_size;
let dst_el = shape[0] * c_out * l_out;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1679,7 +1679,7 @@ pub fn call_col2im1d(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -1687,7 +1687,7 @@ pub fn call_col2im1d(
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1705,7 +1705,7 @@ pub fn call_im2col_strided(
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1718,7 +1718,7 @@ pub fn call_im2col_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -1726,7 +1726,7 @@ pub fn call_im2col_strided(
#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_2d(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -1741,7 +1741,7 @@ pub fn call_upsample_nearest_2d(
let scale_w = shape[2] as f32 / out_w as f32;
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -1750,7 +1750,7 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -1758,7 +1758,7 @@ pub fn call_upsample_nearest_2d(
#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
min: f32,
@ -1773,7 +1773,7 @@ pub fn call_random_uniform(
));
}
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
@ -1788,7 +1788,7 @@ pub fn call_random_uniform(
);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -1796,7 +1796,7 @@ pub fn call_random_uniform(
#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
mean: f32,
@ -1806,7 +1806,7 @@ pub fn call_random_normal(
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
@ -1821,7 +1821,7 @@ pub fn call_random_normal(
);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -1847,7 +1847,7 @@ pub enum GgmlDType {
#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_mv_t(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: GgmlDType,
(b, m, n, k): (usize, usize, usize, usize),
@ -1961,7 +1961,7 @@ pub fn call_quantized_matmul_mv_t(
};
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1993,7 +1993,7 @@ pub fn call_quantized_matmul_mv_t(
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -2005,7 +2005,7 @@ fn divide(m: usize, b: usize) -> NSUInteger {
#[allow(clippy::too_many_arguments)]
pub fn call_pool2d(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@ -2022,7 +2022,7 @@ pub fn call_pool2d(
let dst_el = out_w * out_h * shape[0] * shape[1];
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -2031,14 +2031,14 @@ pub fn call_pool2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose1d(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
dilation: usize,
@ -2061,7 +2061,7 @@ pub fn call_conv_transpose1d(
let dst_el = c_out * l_out * b_size;
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -2084,7 +2084,7 @@ pub fn call_conv_transpose1d(
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
@ -2108,7 +2108,7 @@ pub struct CallConvTranspose2dCfg<'a> {
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose2d(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
cfg: CallConvTranspose2dCfg,
@ -2119,7 +2119,7 @@ pub fn call_conv_transpose2d(
let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -2143,14 +2143,14 @@ pub fn call_conv_transpose2d(
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_arg_sort(
device: &Device,
command_buffer: &CommandBufferRef,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
nrows: usize,
@ -2160,7 +2160,7 @@ pub fn call_arg_sort(
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
@ -2180,7 +2180,7 @@ pub fn call_arg_sort(
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
ep.maybe_end_encoding(encoder);
Ok(())
}

View File

@ -160,3 +160,26 @@ macro_rules! set_params {
)*
);
}
pub trait EncoderProvider {
fn encoder(&self) -> &ComputeCommandEncoderRef;
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef);
}
impl EncoderProvider for &metal::CommandBuffer {
fn encoder(&self) -> &ComputeCommandEncoderRef {
self.new_compute_command_encoder()
}
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
enc.end_encoding()
}
}
impl EncoderProvider for &metal::CommandBufferRef {
fn encoder(&self) -> &ComputeCommandEncoderRef {
self.new_compute_command_encoder()
}
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
enc.end_encoding()
}
}