mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)

* Support for (un)-batched rope.
* Use 3d rope in the rope/ropei/rope_thd functions.
* Get the CPU versions to work.
* Fix the cuda version.
* Adapt the metal side.
* Fix the metal tests.
use metal::{
    Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
    FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
pub mod mlx_gemm;
pub mod sort;
pub mod utils;
pub use mlx_gemm::{call_mlx_gemm, GemmDType};
pub use sort::{call_arg_sort, call_mlx_arg_sort};
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};

const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
const MLX_SORT: &str = include_str!("mlx_sort.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
const RANDOM: &str = include_str!("random.metal");
const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    BF16,
    F16,
    F32,
    I64,
    U32,
    U8,
}

impl DType {
    fn size_in_bytes(&self) -> usize {
        match self {
            Self::U8 => 1,
            Self::U32 => 4,
            Self::I64 => 8,
            Self::BF16 => 2,
            Self::F16 => 2,
            Self::F32 => 4,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
    Affine,
    Binary,
    Cast,
    Conv,
    Fill,
    Gemm,
    Indexing,
    MlxSort,
    Quantized,
    Random,
    Reduce,
    Sort,
    Ternary,
    Unary,
    Sdpa,
}

pub mod copy2d {
    pub struct Kernel(pub &'static str);
    pub const FLOAT: Kernel = Kernel("copy2d_f32");
    pub const HALF: Kernel = Kernel("copy2d_f16");
    pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
    pub const I64: Kernel = Kernel("copy2d_i64");
    pub const U32: Kernel = Kernel("copy2d_u32");
    pub const U8: Kernel = Kernel("copy2d_u8");
}

macro_rules! ops {
    ($($name:ident),+) => {

        pub mod contiguous {
            pub struct Kernel(pub &'static str);
            $(
            pub mod $name {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
                pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
                pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
                pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
                pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
                pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
            }
            )+
            pub mod copy {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel("copy_f32");
                pub const HALF: Kernel = Kernel("copy_f16");
                pub const BFLOAT: Kernel = Kernel("copy_bf16");
                pub const I64: Kernel = Kernel("copy_i64");
                pub const U32: Kernel = Kernel("copy_u32");
                pub const U8: Kernel = Kernel("copy_u8");
            }
        }

        pub mod contiguous_tiled {
            pub struct Kernel(pub &'static str);
            $(
            pub mod $name {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled"));
                pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled"));
                pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled"));
                pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled"));
                pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled"));
                pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled"));
            }
            )+
            pub mod copy {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel("copy_f32_tiled");
                pub const HALF: Kernel = Kernel("copy_f16_tiled");
                pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled");
                pub const I64: Kernel = Kernel("copy_i64_tiled");
                pub const U32: Kernel = Kernel("copy_u32_tiled");
                pub const U8: Kernel = Kernel("copy_u8_tiled");
            }
        }

        pub mod strided {
            pub struct Kernel(pub &'static str);
            $(
            pub mod $name {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
                pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
                pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
                pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
                pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
                pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
            }
            )+
            pub mod copy {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel("copy_f32_strided");
                pub const HALF: Kernel = Kernel("copy_f16_strided");
                pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
                pub const I64: Kernel = Kernel("copy_i64_strided");
                pub const U32: Kernel = Kernel("copy_u32_strided");
                pub const U8: Kernel = Kernel("copy_u8_strided");
            }
        }
    };
}

pub mod unary {
    ops!(
        cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
        tanh, recip, silu, sign, sigmoid, const_set
    );
}
pub mod binary {
    ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
}
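
// The `ops!` invocations above stamp out one module per operation, holding the
// generated kernel-name constants for every dtype/layout combination. A minimal
// sketch of the naming scheme (these literals follow directly from the
// `concat!` calls in the macro):
#[cfg(test)]
mod kernel_name_scheme {
    #[test]
    fn generated_names() {
        assert_eq!(super::unary::contiguous::cos::FLOAT.0, "cos_f32");
        assert_eq!(super::unary::contiguous_tiled::cos::HALF.0, "cos_f16_tiled");
        assert_eq!(super::unary::strided::cos::BFLOAT.0, "cos_bf16_strided");
        assert_eq!(super::binary::contiguous::add::U32.0, "add_u32");
    }
}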

#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
    #[error("Could not lock kernel map: {0}")]
    LockError(String),
    #[error("Error while loading library: {0}")]
    LoadLibraryError(String),
    #[error("Error while loading function: {0}")]
    LoadFunctionError(String),
    #[error("Failed to create compute function")]
    FailedToCreateComputeFunction,
    #[error("Failed to create pipeline")]
    FailedToCreatePipeline(String),
    #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
        mnk: (usize, usize, usize),
    },
    #[error("Sdpa {variation} head size was {got}, expected {expected:?}")]
    SdpaHeadSizeMismatch {
        variation: &'static str,
        got: usize,
        expected: Vec<usize>,
    },
    #[error("Sdpa {variation} got dtype {got:?}")]
    SdpaHeadDTypeMismatch {
        variation: &'static str,
        got: SdpaDType,
    },
}

impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
    fn from(e: std::sync::PoisonError<T>) -> Self {
        Self::LockError(e.to_string())
    }
}

#[derive(Debug, Clone)]
pub enum KernelName {
    Ref(&'static str),
    Value(String),
}

impl AsRef<str> for KernelName {
    fn as_ref(&self) -> &str {
        match self {
            Self::Ref(r) => r,
            Self::Value(v) => v.as_str(),
        }
    }
}

impl std::hash::Hash for KernelName {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            Self::Ref(r) => r.hash(state),
            Self::Value(v) => v.hash(state),
        }
    }
}

impl PartialEq for KernelName {
    fn eq(&self, other: &Self) -> bool {
        let v1: &str = self.as_ref();
        let v2: &str = other.as_ref();
        v1 == v2
    }
}

impl Eq for KernelName {}

impl From<&'static str> for KernelName {
    fn from(value: &'static str) -> Self {
        Self::Ref(value)
    }
}

impl From<String> for KernelName {
    fn from(value: String) -> Self {
        Self::Value(value)
    }
}

type Libraries = HashMap<Source, Library>;
type Pipelines = HashMap<(KernelName, Option<ConstantValues>), ComputePipelineState>;

#[derive(Debug)]
pub struct Kernels {
    libraries: RwLock<Libraries>,
    pipelines: RwLock<Pipelines>,
}

impl Default for Kernels {
    fn default() -> Self {
        Self::new()
    }
}

impl Kernels {
    pub fn new() -> Self {
        let libraries = RwLock::new(Libraries::new());
        let pipelines = RwLock::new(Pipelines::new());
        Self {
            libraries,
            pipelines,
        }
    }

    fn get_library_source(&self, source: Source) -> &'static str {
        match source {
            Source::Affine => AFFINE,
            Source::Binary => BINARY,
            Source::Cast => CAST,
            Source::Conv => CONV,
            Source::Fill => FILL,
            Source::Gemm => MLX_GEMM,
            Source::Indexing => INDEXING,
            Source::MlxSort => MLX_SORT,
            Source::Quantized => QUANTIZED,
            Source::Random => RANDOM,
            Source::Reduce => REDUCE,
            Source::Sort => SORT,
            Source::Ternary => TERNARY,
            Source::Unary => UNARY,
            Source::Sdpa => SDPA,
        }
    }

    /// Load the given library from its [`source`].
    /// If it has been loaded previously, this just fetches it from the cache.
    pub fn load_library(
        &self,
        device: &Device,
        source: Source,
    ) -> Result<Library, MetalKernelError> {
        let mut libraries = self.libraries.write()?;
        if let Some(lib) = libraries.get(&source) {
            Ok(lib.clone())
        } else {
            let lib = {
                let source_content = self.get_library_source(source);
                device
                    .new_library_with_source(source_content, &CompileOptions::new())
                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
            };
            libraries.insert(source, lib.clone());
            Ok(lib)
        }
    }

    fn load_function(
        &self,
        device: &Device,
        source: Source,
        name: &str,
        constants: Option<FunctionConstantValues>,
    ) -> Result<Function, MetalKernelError> {
        let func = self
            .load_library(device, source)?
            .get_function(name, constants)
            .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
        Ok(func)
    }

    /// Load the given pipeline:
    /// loads the library from source, then gets the function [`name`] from
    /// that source.
    fn load_pipeline_with_constants(
        &self,
        device: &Device,
        source: Source,
        name: impl Into<KernelName>,
        constants: Option<ConstantValues>,
    ) -> Result<ComputePipelineState, MetalKernelError> {
        let mut pipelines = self.pipelines.write()?;
        let key = (name.into(), constants);
        if let Some(pipeline) = pipelines.get(&key) {
            Ok(pipeline.clone())
        } else {
            let (name, constants) = key;
            let func = self.load_function(
                device,
                source,
                name.as_ref(),
                constants.as_ref().map(|c| c.function_constant_values()),
            )?;
            let pipeline = device
                .new_compute_pipeline_state_with_function(&func)
                .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
            pipelines.insert((name, constants), pipeline.clone());

            Ok(pipeline)
        }
    }

    /// Load the given pipeline:
    /// loads the library from source, then gets the function [`name`] from
    /// that source (without function constants).
    pub fn load_pipeline(
        &self,
        device: &Device,
        source: Source,
        name: impl Into<KernelName>,
    ) -> Result<ComputePipelineState, MetalKernelError> {
        self.load_pipeline_with_constants(device, source, name, None)
    }
}
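
// A minimal usage sketch of the pipeline cache (assumes a macOS machine with a
// Metal device; mirrors how the `call_*` helpers below use `Kernels`). The
// second load is a cheap cache hit rather than a recompilation.
#[cfg(test)]
mod pipeline_cache_sketch {
    use super::*;

    #[test]
    fn load_pipeline_twice() {
        let device = Device::system_default().expect("no Metal device available");
        let kernels = Kernels::new();
        let p1 = kernels
            .load_pipeline(&device, Source::Unary, "cos_f32")
            .unwrap();
        // Served from the RwLock-guarded pipelines HashMap this time.
        let p2 = kernels
            .load_pipeline(&device, Source::Unary, "cos_f32")
            .unwrap();
        assert_eq!(
            p1.max_total_threads_per_threadgroup(),
            p2.max_total_threads_per_threadgroup()
        );
    }
}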

#[allow(clippy::too_many_arguments)]
pub fn call_copy2d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: copy2d::Kernel,
    input: &Buffer,
    output: &Buffer,
    d1: usize,
    d2: usize,
    src_s: usize,
    dst_s: usize,
    src_o_in_bytes: usize,
    dst_o_in_bytes: usize,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            d1 as i64,
            d2 as i64,
            src_s as i64,
            dst_s as i64,
            (input, src_o_in_bytes),
            (output, dst_o_in_bytes)
        )
    );

    let grid_dims = MTLSize {
        width: d1 as u64,
        height: d2 as u64,
        depth: 1,
    };
    let group_dims = get_block_dims(d1 as u64, d2 as u64, 1);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_threads(grid_dims, group_dims);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous_tiled(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: unary::contiguous_tiled::Kernel,
    length: usize,
    input: impl EncoderParam,
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let tile_size = 2;
    let tiles = length.div_ceil(tile_size);

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, input, &output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: impl EncoderParam,
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, input, &output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_const_set_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: unary::strided::Kernel,
    shape: &[usize],
    input: impl EncoderParam,
    strides: &[usize],
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;

    let length: usize = shape.iter().product();
    let num_dims: usize = shape.len();
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, num_dims, shape, strides, input, &output));
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: unary::contiguous_tiled::Kernel,
    length: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let tile_size = 2;
    let tiles = length.div_ceil(tile_size);

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}
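
// End-to-end sketch of dispatching one of these kernels (an assumption-laden
// example: it needs a Metal device and shared-storage buffers, and mirrors the
// pattern used by this crate's own tests: build a command buffer, call the
// helper, commit, wait, then read the result back from the shared buffer).
#[cfg(test)]
mod unary_dispatch_sketch {
    use super::*;

    #[test]
    fn cos_contiguous_f32() {
        let device = Device::system_default().expect("no Metal device available");
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let input = [0.0f32, 1.0, 2.0];
        let byte_len = std::mem::size_of_val(&input) as u64;
        let options = metal::MTLResourceOptions::StorageModeShared;
        let input_buffer =
            device.new_buffer_with_data(input.as_ptr() as *const c_void, byte_len, options);
        let output = device.new_buffer(byte_len, options);
        call_unary_contiguous(
            &device,
            command_buffer,
            &kernels,
            unary::contiguous::cos::FLOAT,
            input.len(),
            BufferOffset {
                buffer: &input_buffer,
                offset_in_bytes: 0,
            },
            &output,
        )
        .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        // Shared storage means the CPU can read the result back directly.
        let out =
            unsafe { std::slice::from_raw_parts(output.contents() as *const f32, input.len()) };
        assert!((out[0] - 1.0).abs() < 1e-5); // cos(0) == 1
    }
}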

#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: unary::strided::Kernel,
    shape: &[usize],
    input: BufferOffset,
    strides: &[usize],
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;

    let length: usize = shape.iter().product();
    let num_dims: usize = shape.len();
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_binary_contiguous(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: binary::contiguous::Kernel,
    length: usize,
    left: BufferOffset,
    right: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, &left, &right, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_binary_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: binary::strided::Kernel,
    shape: &[usize],
    left_input: BufferOffset,
    left_strides: &[usize],
    right_input: BufferOffset,
    right_strides: &[usize],
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;

    let num_dims: usize = shape.len();
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let length: usize = shape.iter().product();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            left_strides,
            right_strides,
            &left_input,
            &right_input,
            output
        )
    );
    encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_cast_contiguous(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_cast_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    input: BufferOffset,
    input_strides: &[usize],
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();

    set_params!(
        encoder,
        (length, shape.len(), shape, input_strides, &input, output)
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    out_length: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let length = shape.iter().product::<usize>();
    let num_dims = shape.len();
    let work_per_threadgroup = length / out_length;
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            work_per_threadgroup,
            &input,
            output
        )
    );

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        (work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
    );

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}
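
// Worked example of the threadgroup sizing used by the reduce entry points:
// reducing a (10, 4096) input down to 10 outputs gives
// work_per_threadgroup = 40960 / 10 = 4096, so the requested width is
// min(max_threads, (4096 / 2).next_power_of_two()) = min(1024, 2048) = 1024
// on a device capped at 1024 threads per threadgroup; the kernel's loop then
// covers the remaining elements per thread.
#[cfg(test)]
mod reduce_sizing_arithmetic {
    #[test]
    fn threadgroup_width() {
        let length = 10usize * 4096;
        let out_length = 10usize;
        let work_per_threadgroup = length / out_length;
        let max_threads = 1024u64; // a typical Apple-GPU cap (assumption)
        let width = std::cmp::min(
            max_threads,
            (work_per_threadgroup / 2).next_power_of_two() as u64,
        );
        assert_eq!(width, 1024);
    }
}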

#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_length: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let length: usize = shape.iter().product();
    let num_dims = shape.len();
    let work_per_threadgroup = length / out_length;
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            strides,
            work_per_threadgroup,
            &input,
            output
        )
    );

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        (work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
    );

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    elements: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let work_per_threadgroup = elements;

    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (length, work_per_threadgroup, (input, input_offset), output)
    );

    let out_length = length / work_per_threadgroup;

    let thread_group_count = MTLSize {
        width: out_length as NSUInteger,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        (work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
    );

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_rms_norm(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    elements_to_sum: usize,
    eps: f32,
    input: &Buffer,
    input_offset: usize,
    alpha: &Buffer,
    alpha_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            length,
            elements_to_sum,
            (input, input_offset),
            output,
            (alpha, alpha_offset),
            eps
        )
    );

    let out_length = length / elements_to_sum;

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        elements_to_sum as u64,
    )
    .next_power_of_two();

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_layer_norm(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    elements_to_sum: usize,
    eps: f32,
    input: &Buffer,
    input_offset: usize,
    alpha: &Buffer,
    alpha_offset: usize,
    beta: &Buffer,
    beta_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            length,
            elements_to_sum,
            (input, input_offset),
            output,
            (alpha, alpha_offset),
            (beta, beta_offset),
            eps
        )
    );

    let out_length = length / elements_to_sum;

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        elements_to_sum as u64,
    )
    .next_power_of_two();

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}
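
// A note on the set_threadgroup_memory_length calls above (an inference from
// the byte arithmetic, not documented by the kernels): call_rms_norm reserves
// 4 bytes per thread (one f32 partial sum), while call_layer_norm reserves
// 8 bytes per thread (two f32 partials, for the mean and the variance); the
// .max(16) / .max(32) floors avoid requesting a zero-length threadgroup
// allocation for degenerate widths.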

#[allow(clippy::too_many_arguments)]
pub fn call_rope_i(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    bh: usize,
    td: usize,
    stride_b: usize,
    src: &Buffer,
    src_offset: usize,
    cos: &Buffer,
    cos_offset: usize,
    sin: &Buffer,
    sin_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            bh,
            td,
            stride_b,
            (src, src_offset),
            (cos, cos_offset),
            (sin, sin_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
    encoder.use_resource(src, metal::MTLResourceUsage::Read);
    encoder.use_resource(cos, metal::MTLResourceUsage::Read);
    encoder.use_resource(sin, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_rope_thd(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    b: usize,
    t: usize,
    h: usize,
    d: usize,
    stride_b: usize,
    src: &Buffer,
    src_offset: usize,
    cos: &Buffer,
    cos_offset: usize,
    sin: &Buffer,
    sin_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            b,
            t,
            h,
            d,
            stride_b,
            (src, src_offset),
            (cos, cos_offset),
            (sin, sin_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2);
    encoder.use_resource(src, metal::MTLResourceUsage::Read);
    encoder.use_resource(cos, metal::MTLResourceUsage::Read);
    encoder.use_resource(sin, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_rope(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: &'static str,
    bh: usize,
    td: usize,
    d: usize,
    stride_b: usize,
    src: &Buffer,
    src_offset: usize,
    cos: &Buffer,
    cos_offset: usize,
    sin: &Buffer,
    sin_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            bh,
            td,
            d,
            stride_b,
            (src, src_offset),
            (cos, cos_offset),
            (sin, sin_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
    encoder.use_resource(src, metal::MTLResourceUsage::Read);
    encoder.use_resource(cos, metal::MTLResourceUsage::Read);
    encoder.use_resource(sin, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}
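
// The three rope entry points above cover the layouts touched by the
// (un)-batched rope change from this commit: call_rope_i for interleaved
// pairs, call_rope for half-split pairs at distance d, and call_rope_thd for
// explicit (b, t, h, d) shapes. In every variant one GPU thread rotates a
// single (x1, x2) pair, which is why each dispatch sizes the grid at half the
// element count, e.g. (bh * td) / 2; the stride_b argument threads the batch
// stride through to the kernel, which is what the (un)-batched support hinges
// on.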

#[allow(clippy::too_many_arguments)]
pub fn call_affine(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, add, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_affine_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: BufferOffset,
    input_stride: &[usize],
    output: &Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            add,
            &input,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_powf(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_powf_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: BufferOffset,
    input_stride: &[usize],
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (size, shape.len(), shape, input_stride, mul, &input, output)
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_elu(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_elu_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: BufferOffset,
    input_stride: &[usize],
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (size, shape.len(), shape, input_stride, mul, &input, output)
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}
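
// call_powf and call_elu deliberately load from Source::Affine: their kernels
// live in affine.metal alongside the affine ones, so they share a compiled
// library and differ from call_affine only in the kernel name and in taking a
// single mul scalar instead of a (mul, add) pair.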

#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    cond: BufferOffset,
    cond_stride: &[usize],
    left: BufferOffset,
    left_stride: &[usize],
    right: BufferOffset,
    right_stride: &[usize],
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    let size: usize = shape.iter().product();
    let rank = shape.len();

    set_params!(
        encoder,
        (
            size,
            rank,
            shape,
            cond_stride,
            left_stride,
            right_stride,
            &cond,
            &left,
            &right,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);

    encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_index_select(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    contiguous: bool,
    src_dims: &[usize],
    src_strides: &[usize],
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let src_dim_size = shape[dim];
    let dst_el = ids_size * left_size * right_size;

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            ids_size,
            contiguous,
            src_dims,
            src_strides,
            &input,
            &ids,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_gather(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let src_dim_size = shape[dim];
    let dst_el = ids_size * left_size * right_size;

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            ids_size,
            &input,
            &ids,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_scatter(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    dim: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            &input,
            &ids,
            &output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    ids_shape: &[usize],
    dim: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];
    let ids_dim_size = ids_shape[0];

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            ids_dim_size,
            &input,
            &ids,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}
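
// Worked example of the left/dim/right decomposition used by the indexing
// helpers above: for shape = [2, 5, 3] and dim = 1, left_size = 2,
// src_dim_size = 5 and right_size = 3, so selecting 4 indices produces
// dst_el = 4 * 2 * 3 = 24 output elements, and one thread is dispatched per
// destination element.
#[cfg(test)]
mod index_decomposition_arithmetic {
    #[test]
    fn dst_el_for_index_select() {
        let shape = [2usize, 5, 3];
        let dim = 1;
        let ids_size = 4usize;
        let left_size: usize = shape[..dim].iter().product();
        let right_size: usize = shape[dim + 1..].iter().product();
        assert_eq!(ids_size * left_size * right_size, 24);
    }
}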

#[derive(Debug, PartialEq)]
pub enum Value {
    USize(usize),
    Bool(bool),
    F32(f32),
    U16(u16),
}

impl std::hash::Hash for Value {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            Value::F32(v) => v.to_bits().hash(state),
            Value::USize(v) => v.hash(state),
            Value::U16(v) => v.hash(state),
            Value::Bool(v) => v.hash(state),
        }
    }
}

impl Value {
    fn data_type(&self) -> MTLDataType {
        match self {
            Value::USize(_) => MTLDataType::UInt,
            Value::F32(_) => MTLDataType::Float,
            Value::U16(_) => MTLDataType::UShort,
            Value::Bool(_) => MTLDataType::Bool,
        }
    }
}

/// Not strictly true (the `F32` variant breaks `Eq`'s reflexivity for NaNs),
/// but good enough for our purposes.
impl Eq for Value {}

#[derive(Debug, Eq, PartialEq, Hash)]
struct ConstantValues(Vec<(usize, Value)>);

impl ConstantValues {
    pub fn new(values: Vec<(usize, Value)>) -> Self {
        Self(values)
    }

    fn function_constant_values(&self) -> FunctionConstantValues {
        let f = FunctionConstantValues::new();
        for (index, value) in &self.0 {
            let ty = value.data_type();
            match value {
                Value::USize(v) => {
                    f.set_constant_value_at_index(
                        v as *const usize as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::F32(v) => {
                    f.set_constant_value_at_index(
                        v as *const f32 as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::U16(v) => {
                    f.set_constant_value_at_index(
                        v as *const u16 as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::Bool(v) => {
                    f.set_constant_value_at_index(
                        v as *const bool as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
            }
        }
        f
    }
}
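
// ConstantValues feed Metal function constants (the `[[function_constant(i)]]`
// slots in the shaders) and are also part of the pipeline-cache key, so the
// same function specialized with different constants yields distinct cached
// pipelines. call_sdpa_vector below, for example, pins constant index 20
// (`sdpa_vector_has_mask`) to false.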

#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum SdpaDType {
    BF16,
    F16,
    F32,
}

/// SDPA full is supported when:
/// - q head dim == 32, 64, 96, 128 or 256
/// - no mask
/// - q heads == kv heads
/// - final type != bf16 (TODO maybe just template this kernel too?)
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_full(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    q_offset: usize,
    q_shape: &[usize],
    q_buffer: &Buffer,
    k_offset: usize,
    k_buffer: &Buffer,
    v_offset: usize,
    v_buffer: &Buffer,
    output: &Buffer,
    alpha: f32,
    softcapping: f32,
    itype: SdpaDType,
) -> Result<(), MetalKernelError> {
    #[derive(Debug)]
    #[repr(C)]
    struct MLXFastAttentionParams {
        m: i32,
        n: i32,
        k: i32,

        ldq: i32, // ldq == ldo
        ldk: i32,
        ldv: i32,
        lds: i32,
        ldo: i32,

        tiles_n: i32,
        tiles_m: i32,

        batch_stride_q: i32,
        batch_stride_k: i32,
        batch_stride_v: i32,
        batch_stride_o: i32,

        swizzle_log: i32,
        gemm_n_iterations_aligned: i32,
        gemm_k_iterations_aligned: i32,
        gemm_sv_m_block_iterations: i32,

        batch_ndim: i32,
        alpha: f32,
        softcapping: f32,
    }

    let bk = q_shape.last().unwrap();

    const BN: usize = 16;
    const BM: usize = 16;
    const WM: usize = 2;
    const WN: usize = 2;

    let name = match (bk, itype) {
        (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
        (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
        (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
        (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
        (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
        (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
        (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
        (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
        (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
        (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
        (other, SdpaDType::F16 | SdpaDType::F32) => {
            return Err(MetalKernelError::SdpaHeadSizeMismatch {
                variation: "full",
                got: *other,
                expected: vec![32, 64, 96, 128, 256],
            })
        }
        (_, SdpaDType::BF16) => {
            return Err(MetalKernelError::SdpaHeadDTypeMismatch {
                variation: "full",
                got: SdpaDType::BF16,
            })
        }
    };

    let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    // q = (bs, qhead, seq, hidden)
    // k/v = (bs, kv_head, seq, hidden)

    let qseq = q_shape[q_shape.len() - 2];

    let m = q_shape[q_shape.len() - 2];
    let n = m;
    let k = q_shape[q_shape.len() - 1];
    let bs_out = q_shape[0] * q_shape[1];

    let batch_shape = [q_shape[0] * q_shape[1]];
    let dk = q_shape[q_shape.len() - 1];
    let ldq = dk;
    let ldk = dk;
    let ldv = dk;
    let lds = BN;
    let ldo = dk;

    let tn = 1;
    let tm = m.div_ceil(BM);

    let b_stride_q = dk * qseq;
    let b_stride_k = dk * qseq;
    let b_stride_v = dk * qseq;
    let b_stride_o = dk * qseq;
    let swizzle_log = 0;
    let gemm_n_iterations_aligned = n.div_ceil(BN);
    let gemm_k_iterations_aligned = k.div_ceil(*bk);
    let gemm_sv_m_block_iterations = m.div_ceil(BM);
    let batch_ndim = batch_shape.len();

    let alpha = if softcapping != 1. {
        alpha / softcapping
    } else {
        alpha
    };

    let params = MLXFastAttentionParams {
        m: m as i32,
        n: n as i32,
        k: k as i32,
        ldq: ldq as i32,
        ldk: ldk as i32,
        ldv: ldv as i32,
        lds: lds as i32,
        ldo: ldo as i32,
        tiles_n: tn,
        tiles_m: tm as i32,
        batch_stride_q: b_stride_q as i32,
        batch_stride_k: b_stride_k as i32,
        batch_stride_v: b_stride_v as i32,
        batch_stride_o: b_stride_o as i32,
        swizzle_log,
        gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
        gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
        gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
        batch_ndim: batch_ndim as i32,
        alpha,
        softcapping,
    };
    let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];

    impl EncoderParam for MLXFastAttentionParams {
        fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
            encoder.set_bytes(
                position,
                core::mem::size_of::<MLXFastAttentionParams>() as u64,
                &data as *const MLXFastAttentionParams as *const c_void,
            );
        }
    }

    set_params!(
        encoder,
        (
            (q_buffer, q_offset),
            (k_buffer, k_offset),
            (v_buffer, v_offset),
            output,
            params,
            &batch_shape[..],
            &batch_strides[..]
        )
    );

    let grid_dims = MTLSize {
        width: 1,
        height: tm as u64,
        depth: bs_out as u64,
    };
    let group_dims = MTLSize {
        width: 32,
        height: WM as u64,
        depth: WN as u64,
    };
    encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(grid_dims, group_dims);
    Ok(())
}

/// SDPA vector is supported when:
/// - q head dim == 32, 64, 96, 128 or 256
/// - no mask
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    q_offset: usize,
    q_shape: &[usize],
    q_buffer: &Buffer,
    k_offset: usize,
    k_shape: &[usize],
    k_stride: &[usize],
    k_buffer: &Buffer,
    v_offset: usize,
    v_stride: &[usize],
    v_buffer: &Buffer,
    output: &Buffer,
    alpha: f32,
    softcapping: f32,
    itype: SdpaDType,
) -> Result<(), MetalKernelError> {
    let bk = q_shape.last().unwrap();

    let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
    let n = k_shape[2] as i32;
    let b = (q_shape[0] * q_shape[1]) as i32;
    let kstride = k_stride[1];
    let vstride = v_stride[1];

    let name = match (bk, itype) {
        (32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
        (64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
        (96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
        (128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
        (256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
        (32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
        (64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
        (96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
        (128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
        (256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
        (32, SdpaDType::F32) => "sdpa_vector_float_32",
        (64, SdpaDType::F32) => "sdpa_vector_float_64",
        (96, SdpaDType::F32) => "sdpa_vector_float_96",
        (128, SdpaDType::F32) => "sdpa_vector_float_128",
        (256, SdpaDType::F32) => "sdpa_vector_float_256",
        (other, _) => {
            return Err(MetalKernelError::SdpaHeadSizeMismatch {
                variation: "vector",
                got: *other,
                expected: vec![32, 64, 96, 128, 256],
            })
        }
    };

    let alpha = if softcapping != 1. {
        alpha / softcapping
    } else {
        alpha
    };

    let constants = Some(ConstantValues::new(vec![(
        20,
        Value::Bool(/* sdpa_vector_has_mask */ false),
    )]));

    let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    // q = (bs, qhead, seq, hidden)
    // k/v = (bs, kv_head, kv_seq, hidden)

    set_params!(
        encoder,
        (
            (q_buffer, q_offset),
            (k_buffer, k_offset),
            (v_buffer, v_offset),
            output,
            gqa_factor,
            n,
            kstride,
            vstride,
            alpha,
            softcapping
        )
    );

    let grid_dims = MTLSize {
        width: 1,
        height: b as u64,
        depth: 1_u64,
    };
    let group_dims = MTLSize {
        width: 1024,
        height: 1,
        depth: 1,
    };
    encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(grid_dims, group_dims);
    Ok(())
}
|
|
|
|
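/// Number of blocks the kv sequence is split into for the two-pass variant;
/// each block produces a partial output (plus its exp-sum and row max) that
/// the second pass then combines.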
pub const SDPA_2PASS_BLOCKS: usize = 32;

/// SDPA vector 2pass is supported when:
/// - q head dim == 32, 64, 96, 128, 256
/// - no mask
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector_2pass(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    q_offset: usize,
    q_shape: &[usize],
    q_buffer: &Buffer,
    k_offset: usize,
    k_shape: &[usize],
    k_stride: &[usize],
    k_buffer: &Buffer,
    v_offset: usize,
    v_stride: &[usize],
    v_buffer: &Buffer,
    output: &Buffer,
    intermediate: &Buffer,
    sums: &Buffer,
    maxs: &Buffer,
    alpha: f32,
    softcapping: f32,
    itype: SdpaDType,
) -> Result<(), MetalKernelError> {
    let bk = q_shape.last().unwrap();

    // First pass
    {
        let name_pass1 = match (bk, itype) {
            (32, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_32",
            (64, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_64",
            (96, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_96",
            (128, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_128",
            (256, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_256",
            (32, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_32",
            (64, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_64",
            (96, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_96",
            (128, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_128",
            (256, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_256",
            (32, SdpaDType::F32) => "sdpa_vector_2pass_1_float_32",
            (64, SdpaDType::F32) => "sdpa_vector_2pass_1_float_64",
            (96, SdpaDType::F32) => "sdpa_vector_2pass_1_float_96",
            (128, SdpaDType::F32) => "sdpa_vector_2pass_1_float_128",
            (256, SdpaDType::F32) => "sdpa_vector_2pass_1_float_256",
            (other, _) => {
                return Err(MetalKernelError::SdpaHeadSizeMismatch {
                    variation: "vector_2pass_1",
                    got: *other,
                    expected: vec![32, 64, 96, 128, 256],
                })
            }
        };

        let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
        let n = k_shape[2] as i32;
        let b = (q_shape[0] * q_shape[1]) as i32;
        let kstride = k_stride[1];
        let vstride = v_stride[1];

        let alpha = if softcapping != 1. {
            alpha / softcapping
        } else {
            alpha
        };

        let constants = Some(ConstantValues::new(vec![(
            20,
            Value::Bool(/* sdpa_vector_has_mask */ false),
        )]));

        let pipeline =
            kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?;
        let encoder = ep.encoder();
        let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
        encoder.set_compute_pipeline_state(&pipeline);

        // q = (bs, qhead, seq, hidden)
        // k/v = (bs, kv_head, kv_seq, hidden)

        set_params!(
            encoder,
            (
                (q_buffer, q_offset),
                (k_buffer, k_offset),
                (v_buffer, v_offset),
                intermediate,
                sums,
                maxs,
                gqa_factor,
                n,
                kstride,
                vstride,
                alpha,
                softcapping
            )
        );

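        // One threadgroup of 8 * 32 = 256 threads per (row, kv-block) pair:
        // the grid depth fans the kv sequence out over SDPA_2PASS_BLOCKS
        // blocks, each writing its partial output, exp-sum and row max.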
        let grid_dims = MTLSize {
            width: 1,
            height: b as u64,
            depth: SDPA_2PASS_BLOCKS as u64,
        };
        let group_dims = MTLSize {
            width: 8 * 32,
            height: 1,
            depth: 1,
        };
        encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
        encoder.use_resource(sums, metal::MTLResourceUsage::Write);
        encoder.use_resource(maxs, metal::MTLResourceUsage::Write);

        encoder.dispatch_thread_groups(grid_dims, group_dims);
    }

    // Final pass
    {
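        // The second pass folds the per-block partials back together in the
        // usual streaming-softmax way: each block's contribution is rescaled
        // by exp(block_max - global_max), summed, and normalized by the total
        // exp-sum before being written to `output`.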
        let name_pass2 = match (bk, itype) {
            (32, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_32",
            (64, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_64",
            (96, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_96",
            (128, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_128",
            (256, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_256",
            (32, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_32",
            (64, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_64",
            (96, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_96",
            (128, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_128",
            (256, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_256",
            (32, SdpaDType::F32) => "sdpa_vector_2pass_2_float_32",
            (64, SdpaDType::F32) => "sdpa_vector_2pass_2_float_64",
            (96, SdpaDType::F32) => "sdpa_vector_2pass_2_float_96",
            (128, SdpaDType::F32) => "sdpa_vector_2pass_2_float_128",
            (256, SdpaDType::F32) => "sdpa_vector_2pass_2_float_256",
            (other, _) => {
                return Err(MetalKernelError::SdpaHeadSizeMismatch {
                    variation: "vector_2pass_2",
                    got: *other,
                    expected: vec![32, 64, 96, 128, 256],
                })
            }
        };

        let b = (q_shape[0] * q_shape[1]) as i32;

        let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?;
        let encoder = ep.encoder();
        let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
        encoder.set_compute_pipeline_state(&pipeline);

        // q = (bs, qhead, seq, hidden)
        // k/v = (bs, kv_head, kv_seq, hidden)

        set_params!(encoder, (intermediate, sums, maxs, output));

        let grid_dims = MTLSize {
            width: 1,
            height: b as u64,
            depth: 1,
        };
        let group_dims = MTLSize {
            width: 1024,
            height: 1,
            depth: 1,
        };
        encoder.use_resource(intermediate, metal::MTLResourceUsage::Write);
        encoder.use_resource(sums, metal::MTLResourceUsage::Write);
        encoder.use_resource(maxs, metal::MTLResourceUsage::Write);
        encoder.use_resource(output, metal::MTLResourceUsage::Write);

        encoder.dispatch_thread_groups(grid_dims, group_dims);
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    strides: &[usize],
    (k_size, stride, padding, dilation): (usize, usize, usize, usize),
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
    let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
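    // Standard conv output-length formula. E.g. with an input length of 10,
    // k_size = 3, stride = 2, padding = 1, dilation = 1:
    // l_out = (10 + 2 - 2 - 1) / 2 + 1 = 5.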
    let dst_el = shape[0] * l_out * shape[1] * k_size;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
    );
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_col2im1d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    k_size: usize,
    stride: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
    let l_in = shape[1];
    let c_out = shape[2];
    let l_out = (l_in - 1) * stride + k_size;
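    // Inverse of the im2col length formula for the padding- and dilation-free
    // case: e.g. l_in = 5, stride = 2, k_size = 3 gives l_out = (5 - 1) * 2 + 3 = 11.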
    let dst_el = shape[0] * c_out * l_out;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
    );
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    strides: &[usize],
    (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;

    let h = shape[2];
    let w = shape[3];
    let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
    let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;

    let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;

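    // One output element per (batch, out position, in channel, kernel tap):
    // the same length formula as the 1d case above, applied to both h and w
    // with a shared stride/padding/dilation.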
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
            output
        )
    );
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_2d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_w: usize,
    out_h: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
    let dst_el = out_w * out_h * shape[0] * shape[1];
    let scale_w = shape[2] as f32 / out_w as f32;
    let scale_h = shape[3] as f32 / out_h as f32;
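    // Nearest-neighbor scale factors: each output coordinate maps back to
    // input space via in_dim / out_dim. Note the pairing of shape[2] with
    // out_w and shape[3] with out_h; whichever convention the metal kernel
    // uses when indexing, this pairing has to stay consistent with it.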
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
    );
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    min: f32,
    max: f32,
    length: usize,
    seed: &Buffer,
    buffer: &Buffer,
) -> Result<(), MetalKernelError> {
    if min >= max {
        return Err(MetalKernelError::LoadLibraryError(
            "min must be less than max".to_string(),
        ));
    }
    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

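    // The half-length dispatch implies each thread of the random kernel
    // produces a pair of values; the `odd` correction covers odd lengths.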
    let odd = (length % 2 != 0) as usize;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, min, max, seed, buffer));

    encoder.use_resource(
        seed,
        metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write,
    );
    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    mean: f32,
    stddev: f32,
    length: usize,
    seed: &Buffer,
    buffer: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

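    // Same pair-per-thread dispatch as the uniform case; Box-Muller style
    // generators naturally emit normal samples in pairs, which would explain
    // the length / 2 split (see random.metal for the actual kernel).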
    let odd = (length % 2 != 0) as usize;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, mean, stddev, seed, buffer));

    encoder.use_resource(
        seed,
        metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write,
    );
    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[derive(Debug, Clone, Copy)]
pub enum GgmlDType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    F16,
    F32,
}

#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_mv_t(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    dtype: GgmlDType,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs: &Buffer,
    lhs_offset: usize,
    rhs: &Buffer,
    dst_offset: usize,
    dst: &Buffer,
) -> Result<(), MetalKernelError> {
    // Everything is in reverse
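    // ggml convention: `ne*` counts elements per dimension and `nb*` holds
    // byte strides, with dims listed innermost-first, hence "in reverse"
    // relative to the (b, m, n, k) arguments. ne00/ne10 are the shared k dim,
    // ne01 the weight rows (n), ne11 the input rows (m).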
    let ne00 = k as i64;
    let ne01 = n as i64;
    let ne02 = b as i64;
    let ne03 = 1i64;

    let nb00 = 0i64;
    let nb01 = 0i64;
    let nb02 = 0i64;

    let ne10 = k as i64;
    let ne11 = m as i64;
    let ne12 = b as i64;
    let ne13 = 1i64;

    let nb10 = 0i64;
    let nb11 = 0i64;
    let nb12 = 0i64;

    let ne0 = n as i64;
    let ne1 = m as i64;
    let r2: u32 = (ne12 / ne02) as u32;
    let r3: u32 = (ne13 / ne03) as u32;

    let (nth0, nth1, align) = match dtype {
        GgmlDType::Q4_0
        | GgmlDType::Q4_1
        | GgmlDType::Q5_0
        | GgmlDType::Q5_1
        | GgmlDType::Q8_0
        | GgmlDType::Q8_1 => {
            let nth0 = 8;
            let nth1 = 8;
            let align = 8;
            (nth0, nth1, align)
        }
        GgmlDType::Q2K => {
            // Fixing a bug in Metal for GGML
            // https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576
            let nth0 = 2;
            let nth1 = 32;
            let align = 4;
            (nth0, nth1, align)
        }
        GgmlDType::Q4K => {
            let nth0 = 4;
            let nth1 = 8;
            let align = 4;
            (nth0, nth1, align)
        }
        GgmlDType::Q3K | GgmlDType::Q5K => {
            let nth0 = 2;
            let nth1 = 32;
            let align = 4;
            (nth0, nth1, align)
        }
        GgmlDType::Q6K => {
            let nth0 = 2;
            let nth1 = 32;
            let align = 2;
            (nth0, nth1, align)
        }
        GgmlDType::F16 | GgmlDType::Q8K => {
            // Original implementation uses rows
            let nth0 = 32;
            let nth1 = 1;
            let align = 8;
            (nth0, nth1, align)
        }
        GgmlDType::F32 => {
            let nth0 = 32;
            let nth1 = 1;
            let align = 8;
            (nth0, nth1, align)
        }
    };
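    // Each threadgroup covers `align` rows of the weight matrix, so the grid
    // width is ceil(ne01 / align); height walks the input rows and depth the
    // (broadcast) batch dims.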
    let thread_groups_count = MTLSize {
        width: divide(ne01 as usize, align),
        height: ne11 as u64,
        depth: (ne12 * ne13) as u64,
    };
    let threads_per_threadgroup = MTLSize {
        width: nth0,
        height: nth1,
        depth: 1,
    };
    let name = match dtype {
        GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
        GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
        GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
        GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
        GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
        GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
        GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
        GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
        GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
        GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
        GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
        GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
        GgmlDType::F16 => "kernel_mul_mv_f16_f32",
        GgmlDType::F32 => "kernel_mul_mv_f32_f32",
    };

    let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            rhs,
            (lhs, lhs_offset),
            (dst, dst_offset),
            ne00,
            ne01,
            ne02,
            nb00,
            nb01,
            nb02,
            ne10,
            ne11,
            ne12,
            nb10,
            nb11,
            nb12,
            ne0,
            ne1,
            r2,
            r3
        )
    );
    encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
    encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
    encoder.use_resource(dst, metal::MTLResourceUsage::Write);

    encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
    Ok(())
}

fn divide(m: usize, b: usize) -> NSUInteger {
    m.div_ceil(b) as NSUInteger
}

#[allow(clippy::too_many_arguments)]
pub fn call_pool2d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_w: usize,
    out_h: usize,
    w_k: usize,
    h_k: usize,
    w_stride: usize,
    h_stride: usize,
    input: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let dst_el = out_w * out_h * shape[0] * shape[1];
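    // One thread per output element over (batch, channel, out_h, out_w);
    // out_w/out_h are precomputed by the caller from the kernel size and the
    // per-axis pooling strides.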
    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (w_k, h_k, w_stride, h_stride, shape, strides, input, output)
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose1d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    dilation: usize,
    stride: usize,
    padding: usize,
    out_padding: usize,
    c_out: usize,
    l_out: usize,
    b_size: usize,
    src_shape: &[usize],
    src_strides: &[usize],
    kernel_shape: &[usize],
    kernel_strides: &[usize],
    input: &Buffer,
    input_offset: usize,
    kernel: &Buffer,
    kernel_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let dst_el = c_out * l_out * b_size;
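    // `l_out` is supplied by the caller; for a transposed conv it follows the
    // usual formula
    // l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1.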
    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            l_out,
            stride,
            padding,
            out_padding,
            dilation,
            src_shape,
            src_strides,
            kernel_shape,
            kernel_strides,
            (input, input_offset),
            (kernel, kernel_offset),
            output
        )
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

pub struct CallConvTranspose2dCfg<'a> {
    pub dilation: usize,
    pub stride: usize,
    pub padding: usize,
    pub output_padding: usize,
    pub c_out: usize,
    pub out_w: usize,
    pub out_h: usize,
    pub b_size: usize,
    pub input_dims: &'a [usize],
    pub input_stride: &'a [usize],
    pub kernel_dims: &'a [usize],
    pub kernel_stride: &'a [usize],
    pub input_offset: usize,
    pub kernel_offset: usize,
}

#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose2d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    cfg: CallConvTranspose2dCfg,
    input: &Buffer,
    kernel: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            cfg.out_w,
            cfg.out_h,
            cfg.stride,
            cfg.padding,
            cfg.output_padding,
            cfg.dilation,
            cfg.input_dims,
            cfg.input_stride,
            cfg.kernel_dims,
            cfg.kernel_stride,
            (input, cfg.input_offset),
            (kernel, cfg.kernel_offset),
            output
        )
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

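/// Fill `output` with the constant `v`; `name` selects the dtype-specialized
/// kernel from fill.metal, so `v` is expected to match that dtype.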
pub fn call_const_fill(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    length: usize,
    output: &Buffer,
    v: impl EncoderParam,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (output, v, length));
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[cfg(test)]
mod tests;