mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Use a trait for the encoder provider (so that the encoder can ultimately be reused). (#2352)
This commit is contained in:
@ -1,6 +1,6 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
|
Buffer, CompileOptions, ComputePipelineState, Device, Function, FunctionConstantValues,
|
||||||
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
Library, MTLDataType, MTLSize, NSUInteger,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -8,7 +8,7 @@ use std::sync::RwLock;
|
|||||||
|
|
||||||
mod utils;
|
mod utils;
|
||||||
pub use utils::BufferOffset;
|
pub use utils::BufferOffset;
|
||||||
use utils::{get_block_dims, linear_split};
|
use utils::{get_block_dims, linear_split, EncoderProvider};
|
||||||
|
|
||||||
const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &str = include_str!("affine.metal");
|
||||||
const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &str = include_str!("indexing.metal");
|
||||||
@ -297,7 +297,7 @@ impl Kernels {
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_copy2d(
|
pub fn call_copy2d(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: copy2d::Kernel,
|
name: copy2d::Kernel,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
@ -310,7 +310,7 @@ pub fn call_copy2d(
|
|||||||
dst_o_in_bytes: usize,
|
dst_o_in_bytes: usize,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
@ -333,14 +333,14 @@ pub fn call_copy2d(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_threads(grid_dims, group_dims);
|
encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_unary_contiguous_tiled(
|
pub fn call_unary_contiguous_tiled(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: unary::contiguous_tiled::Kernel,
|
kernel_name: unary::contiguous_tiled::Kernel,
|
||||||
length: usize,
|
length: usize,
|
||||||
@ -348,7 +348,7 @@ pub fn call_unary_contiguous_tiled(
|
|||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
let tile_size = 2;
|
let tile_size = 2;
|
||||||
let tiles = (length + tile_size - 1) / tile_size;
|
let tiles = (length + tile_size - 1) / tile_size;
|
||||||
|
|
||||||
@ -360,14 +360,14 @@ pub fn call_unary_contiguous_tiled(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_unary_contiguous(
|
pub fn call_unary_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: unary::contiguous::Kernel,
|
kernel_name: unary::contiguous::Kernel,
|
||||||
length: usize,
|
length: usize,
|
||||||
@ -375,7 +375,7 @@ pub fn call_unary_contiguous(
|
|||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
@ -385,14 +385,14 @@ pub fn call_unary_contiguous(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_unary_strided(
|
pub fn call_unary_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: unary::strided::Kernel,
|
name: unary::strided::Kernel,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -404,7 +404,7 @@ pub fn call_unary_strided(
|
|||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
let num_dims: usize = shape.len();
|
let num_dims: usize = shape.len();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -412,14 +412,14 @@ pub fn call_unary_strided(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_binary_contiguous(
|
pub fn call_binary_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: binary::contiguous::Kernel,
|
kernel_name: binary::contiguous::Kernel,
|
||||||
length: usize,
|
length: usize,
|
||||||
@ -429,7 +429,7 @@ pub fn call_binary_contiguous(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, &left, &right, output));
|
set_params!(encoder, (length, &left, &right, output));
|
||||||
@ -440,14 +440,14 @@ pub fn call_binary_contiguous(
|
|||||||
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_binary_strided(
|
pub fn call_binary_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: binary::strided::Kernel,
|
name: binary::strided::Kernel,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -460,7 +460,7 @@ pub fn call_binary_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||||
|
|
||||||
let num_dims: usize = shape.len();
|
let num_dims: usize = shape.len();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
let width: usize = shape.iter().product();
|
let width: usize = shape.iter().product();
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||||
@ -483,7 +483,7 @@ pub fn call_binary_strided(
|
|||||||
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -491,7 +491,7 @@ pub fn call_binary_strided(
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_cast_contiguous(
|
pub fn call_cast_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
@ -500,7 +500,7 @@ pub fn call_cast_contiguous(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, &input, output));
|
set_params!(encoder, (length, &input, output));
|
||||||
@ -509,14 +509,14 @@ pub fn call_cast_contiguous(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_cast_strided(
|
pub fn call_cast_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -526,7 +526,7 @@ pub fn call_cast_strided(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
@ -541,14 +541,14 @@ pub fn call_cast_strided(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_reduce_contiguous(
|
pub fn call_reduce_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
@ -559,7 +559,7 @@ pub fn call_reduce_contiguous(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, elements_to_sum, &input, output));
|
set_params!(encoder, (length, elements_to_sum, &input, output));
|
||||||
@ -585,14 +585,14 @@ pub fn call_reduce_contiguous(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_reduce_strided(
|
pub fn call_reduce_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -605,7 +605,7 @@ pub fn call_reduce_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -634,14 +634,14 @@ pub fn call_reduce_strided(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_last_softmax(
|
pub fn call_last_softmax(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
@ -651,7 +651,7 @@ pub fn call_last_softmax(
|
|||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -682,14 +682,14 @@ pub fn call_last_softmax(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_rms_norm(
|
pub fn call_rms_norm(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
@ -702,7 +702,7 @@ pub fn call_rms_norm(
|
|||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -741,14 +741,14 @@ pub fn call_rms_norm(
|
|||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
|
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_layer_norm(
|
pub fn call_layer_norm(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
@ -763,7 +763,7 @@ pub fn call_layer_norm(
|
|||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -803,14 +803,14 @@ pub fn call_layer_norm(
|
|||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
|
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_rope_i(
|
pub fn call_rope_i(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
bh: usize,
|
bh: usize,
|
||||||
@ -824,7 +824,7 @@ pub fn call_rope_i(
|
|||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -844,14 +844,14 @@ pub fn call_rope_i(
|
|||||||
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_rope_thd(
|
pub fn call_rope_thd(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
b: usize,
|
b: usize,
|
||||||
@ -867,7 +867,7 @@ pub fn call_rope_thd(
|
|||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -889,14 +889,14 @@ pub fn call_rope_thd(
|
|||||||
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_rope(
|
pub fn call_rope(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
bh: usize,
|
bh: usize,
|
||||||
@ -911,7 +911,7 @@ pub fn call_rope(
|
|||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -932,14 +932,14 @@ pub fn call_rope(
|
|||||||
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_affine(
|
pub fn call_affine(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
@ -950,7 +950,7 @@ pub fn call_affine(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, add, &input, output));
|
set_params!(encoder, (size, mul, add, &input, output));
|
||||||
@ -959,14 +959,14 @@ pub fn call_affine(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_affine_strided(
|
pub fn call_affine_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -979,7 +979,7 @@ pub fn call_affine_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1000,14 +1000,14 @@ pub fn call_affine_strided(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_powf(
|
pub fn call_powf(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
@ -1017,7 +1017,7 @@ pub fn call_powf(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, &input, output));
|
set_params!(encoder, (size, mul, &input, output));
|
||||||
@ -1026,14 +1026,14 @@ pub fn call_powf(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_powf_strided(
|
pub fn call_powf_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1045,7 +1045,7 @@ pub fn call_powf_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1057,14 +1057,14 @@ pub fn call_powf_strided(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_elu(
|
pub fn call_elu(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
@ -1074,7 +1074,7 @@ pub fn call_elu(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, &input, output));
|
set_params!(encoder, (size, mul, &input, output));
|
||||||
@ -1083,14 +1083,14 @@ pub fn call_elu(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_elu_strided(
|
pub fn call_elu_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1102,7 +1102,7 @@ pub fn call_elu_strided(
|
|||||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1114,14 +1114,14 @@ pub fn call_elu_strided(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_where_cond_strided(
|
pub fn call_where_cond_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1135,7 +1135,7 @@ pub fn call_where_cond_strided(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
@ -1164,14 +1164,14 @@ pub fn call_where_cond_strided(
|
|||||||
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_index_select(
|
pub fn call_index_select(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1191,7 +1191,7 @@ pub fn call_index_select(
|
|||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
@ -1218,14 +1218,14 @@ pub fn call_index_select(
|
|||||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_gather(
|
pub fn call_gather(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1242,7 +1242,7 @@ pub fn call_gather(
|
|||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
@ -1266,14 +1266,14 @@ pub fn call_gather(
|
|||||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_scatter_add(
|
pub fn call_scatter_add(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
src_shape: &[usize],
|
src_shape: &[usize],
|
||||||
@ -1291,7 +1291,7 @@ pub fn call_scatter_add(
|
|||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
@ -1315,14 +1315,14 @@ pub fn call_scatter_add(
|
|||||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_index_add(
|
pub fn call_index_add(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
src_shape: &[usize],
|
src_shape: &[usize],
|
||||||
@ -1341,7 +1341,7 @@ pub fn call_index_add(
|
|||||||
let ids_dim_size = ids_shape[0];
|
let ids_dim_size = ids_shape[0];
|
||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
@ -1366,7 +1366,7 @@ pub fn call_index_add(
|
|||||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1453,7 +1453,7 @@ impl ConstantValues {
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_gemm(
|
pub fn call_gemm(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
@ -1572,7 +1572,7 @@ pub fn call_gemm(
|
|||||||
};
|
};
|
||||||
let block_bytes = block_elements * bytes;
|
let block_bytes = block_elements * bytes;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||||
@ -1615,7 +1615,7 @@ pub fn call_gemm(
|
|||||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1623,7 +1623,7 @@ pub fn call_gemm(
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_im2col1d_strided(
|
pub fn call_im2col1d_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1636,7 +1636,7 @@ pub fn call_im2col1d_strided(
|
|||||||
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
|
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
|
||||||
let dst_el = shape[0] * l_out * shape[1] * k_size;
|
let dst_el = shape[0] * l_out * shape[1] * k_size;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1646,7 +1646,7 @@ pub fn call_im2col1d_strided(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1654,7 +1654,7 @@ pub fn call_im2col1d_strided(
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_col2im1d(
|
pub fn call_col2im1d(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1669,7 +1669,7 @@ pub fn call_col2im1d(
|
|||||||
let l_out = (l_in - 1) * stride + k_size;
|
let l_out = (l_in - 1) * stride + k_size;
|
||||||
let dst_el = shape[0] * c_out * l_out;
|
let dst_el = shape[0] * c_out * l_out;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1679,7 +1679,7 @@ pub fn call_col2im1d(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1687,7 +1687,7 @@ pub fn call_col2im1d(
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_im2col_strided(
|
pub fn call_im2col_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1705,7 +1705,7 @@ pub fn call_im2col_strided(
|
|||||||
|
|
||||||
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
|
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1718,7 +1718,7 @@ pub fn call_im2col_strided(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1726,7 +1726,7 @@ pub fn call_im2col_strided(
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_upsample_nearest_2d(
|
pub fn call_upsample_nearest_2d(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -1741,7 +1741,7 @@ pub fn call_upsample_nearest_2d(
|
|||||||
let scale_w = shape[2] as f32 / out_w as f32;
|
let scale_w = shape[2] as f32 / out_w as f32;
|
||||||
let scale_h = shape[3] as f32 / out_h as f32;
|
let scale_h = shape[3] as f32 / out_h as f32;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
@ -1750,7 +1750,7 @@ pub fn call_upsample_nearest_2d(
|
|||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1758,7 +1758,7 @@ pub fn call_upsample_nearest_2d(
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_random_uniform(
|
pub fn call_random_uniform(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
min: f32,
|
min: f32,
|
||||||
@ -1773,7 +1773,7 @@ pub fn call_random_uniform(
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
|
|
||||||
let odd = (length % 2 != 0) as usize;
|
let odd = (length % 2 != 0) as usize;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||||
@ -1788,7 +1788,7 @@ pub fn call_random_uniform(
|
|||||||
);
|
);
|
||||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1796,7 +1796,7 @@ pub fn call_random_uniform(
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_random_normal(
|
pub fn call_random_normal(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
mean: f32,
|
mean: f32,
|
||||||
@ -1806,7 +1806,7 @@ pub fn call_random_normal(
|
|||||||
buffer: &Buffer,
|
buffer: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
|
|
||||||
let odd = (length % 2 != 0) as usize;
|
let odd = (length % 2 != 0) as usize;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||||
@ -1821,7 +1821,7 @@ pub fn call_random_normal(
|
|||||||
);
|
);
|
||||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -1847,7 +1847,7 @@ pub enum GgmlDType {
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_quantized_matmul_mv_t(
|
pub fn call_quantized_matmul_mv_t(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
dtype: GgmlDType,
|
dtype: GgmlDType,
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
@ -1961,7 +1961,7 @@ pub fn call_quantized_matmul_mv_t(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -1993,7 +1993,7 @@ pub fn call_quantized_matmul_mv_t(
|
|||||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -2005,7 +2005,7 @@ fn divide(m: usize, b: usize) -> NSUInteger {
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_pool2d(
|
pub fn call_pool2d(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
@ -2022,7 +2022,7 @@ pub fn call_pool2d(
|
|||||||
let dst_el = out_w * out_h * shape[0] * shape[1];
|
let dst_el = out_w * out_h * shape[0] * shape[1];
|
||||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
@ -2031,14 +2031,14 @@ pub fn call_pool2d(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_conv_transpose1d(
|
pub fn call_conv_transpose1d(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
dilation: usize,
|
dilation: usize,
|
||||||
@ -2061,7 +2061,7 @@ pub fn call_conv_transpose1d(
|
|||||||
let dst_el = c_out * l_out * b_size;
|
let dst_el = c_out * l_out * b_size;
|
||||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
@ -2084,7 +2084,7 @@ pub fn call_conv_transpose1d(
|
|||||||
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2108,7 +2108,7 @@ pub struct CallConvTranspose2dCfg<'a> {
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_conv_transpose2d(
|
pub fn call_conv_transpose2d(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
cfg: CallConvTranspose2dCfg,
|
cfg: CallConvTranspose2dCfg,
|
||||||
@ -2119,7 +2119,7 @@ pub fn call_conv_transpose2d(
|
|||||||
let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
|
let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
|
||||||
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
@ -2143,14 +2143,14 @@ pub fn call_conv_transpose2d(
|
|||||||
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_arg_sort(
|
pub fn call_arg_sort(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
nrows: usize,
|
nrows: usize,
|
||||||
@ -2160,7 +2160,7 @@ pub fn call_arg_sort(
|
|||||||
dst: &Buffer,
|
dst: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = ep.encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
||||||
@ -2180,7 +2180,7 @@ pub fn call_arg_sort(
|
|||||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||||
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
ep.maybe_end_encoding(encoder);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,3 +160,26 @@ macro_rules! set_params {
|
|||||||
)*
|
)*
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait EncoderProvider {
|
||||||
|
fn encoder(&self) -> &ComputeCommandEncoderRef;
|
||||||
|
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderProvider for &metal::CommandBuffer {
|
||||||
|
fn encoder(&self) -> &ComputeCommandEncoderRef {
|
||||||
|
self.new_compute_command_encoder()
|
||||||
|
}
|
||||||
|
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
|
||||||
|
enc.end_encoding()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderProvider for &metal::CommandBufferRef {
|
||||||
|
fn encoder(&self) -> &ComputeCommandEncoderRef {
|
||||||
|
self.new_compute_command_encoder()
|
||||||
|
}
|
||||||
|
fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
|
||||||
|
enc.end_encoding()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user