use metal::{
    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");

/// Most kernels apply similarly across tensors.
/// This creates a launch strategy that uses the maximum number of threads per threadgroup
/// (capped at the actual total buffer length), so each kernel invocation can simply apply
/// its op to its single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
    let count = (size + width - 1) / width;
    let thread_group_count = MTLSize {
        width: count,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    (thread_group_count, thread_group_size)
}
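
// Worked example (illustrative, assuming a hypothetical device limit of 1024 threads per
// threadgroup): for `length = 4096`, `linear_split` returns a threadgroup size of 1024 and a
// threadgroup count of ceil(4096 / 1024) = 4; for `length = 100` the width is capped at 100
// with a single threadgroup. Threads whose index falls past `length` (when `length` is not a
// multiple of the width) are expected to bounds-check inside the kernel itself.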

fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
    <P as EncoderParam>::set_param(encoder, position, data)
}

/// Helper trait to set the various arguments on the compute command encoder
/// on a single line.
/// Prevents getting argument positions wrong and mixing up element counts with sizes in bytes.
trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
    ($type:ty) => {
        impl EncoderParam for $type {
            fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
                encoder.set_bytes(
                    position,
                    core::mem::size_of::<$type>() as u64,
                    &data as *const $type as *const c_void,
                );
            }
        }
    };
}

primitive!(usize);
primitive!(u32);
primitive!(f32);

impl<T> EncoderParam for &[T] {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of_val(data) as u64,
            data.as_ptr() as *const c_void,
        );
    }
}

impl EncoderParam for &Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}
impl EncoderParam for (&Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}
impl EncoderParam for &mut Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}
impl EncoderParam for (&mut Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}

macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        let mut _index = 0;
        $(
            set_param($encoder, _index, $param);
            _index += 1;
        )*
    );
}
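
// Expansion sketch (illustrative): `set_params!(encoder, (length, input, output))` expands to
// sequential `set_param` calls with auto-incremented argument indices, i.e. roughly:
//
//     set_param(encoder, 0, length);   // scalar -> encoder.set_bytes(0, ...)
//     set_param(encoder, 1, input);    // &Buffer -> encoder.set_buffer(1, ...)
//     set_param(encoder, 2, output);
//
// so adding or reordering parameters cannot silently desynchronize the buffer indices.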

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
    Affine,
    Indexing,
    Unary,
    Binary,
    Ternary,
    Cast,
    Reduce,
    Mfa,
}

macro_rules! ops {
    ($($name:ident),+) => {

        pub mod contiguous {
            pub struct Kernel(pub &'static str);
            $(
                pub mod $name {
                    use super::Kernel;
                    pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
                    pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
                    pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
                }
            )+
            pub mod copy {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel("copy_f32");
                pub const HALF: Kernel = Kernel("copy_f16");
                pub const BFLOAT: Kernel = Kernel("copy_bf16");
                pub const U32: Kernel = Kernel("copy_u32");
                pub const U8: Kernel = Kernel("copy_u8");
            }
        }

        pub mod strided {
            pub struct Kernel(pub &'static str);
            $(
                pub mod $name {
                    use super::Kernel;
                    pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
                    pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
                    pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
                }
            )+
            pub mod copy {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel("copy_f32_strided");
                pub const HALF: Kernel = Kernel("copy_f16_strided");
                pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
                pub const U32: Kernel = Kernel("copy_u32_strided");
                pub const U8: Kernel = Kernel("copy_u8_strided");
            }
        }
    };
}

pub mod unary {
    ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
    ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
}
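
// Naming sketch (illustrative): each generated constant carries the Metal function name the
// kernel was compiled under, e.g. `unary::contiguous::cos::FLOAT.0 == "cos_f32"` and
// `binary::strided::add::HALF.0 == "add_f16_strided"`; this is the string later handed to
// `Kernels::load_pipeline`.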

#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
    #[error("Could not lock kernel map: {0}")]
    LockError(String),
    #[error("Error while loading library: {0}")]
    LoadLibraryError(String),
    #[error("Error while loading function: {0:?}")]
    LoadFunctionError(String),
    #[error("Failed to create compute function")]
    FailedToCreateComputeFunction,
    #[error("Failed to create pipeline")]
    FailedToCreatePipeline(String),
    #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
        mnk: (usize, usize, usize),
    },
}

impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
    fn from(e: std::sync::PoisonError<T>) -> Self {
        Self::LockError(e.to_string())
    }
}

type Libraries = HashMap<Source, Library>;
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;

#[derive(Debug)]
pub struct Kernels {
    libraries: RwLock<Libraries>,
    pipelines: RwLock<Pipelines>,
    fence: metal::Fence,
}

impl Kernels {
    pub fn new(fence: metal::Fence) -> Self {
        let libraries = RwLock::new(Libraries::new());
        let pipelines = RwLock::new(Pipelines::new());
        Self {
            libraries,
            pipelines,
            fence,
        }
    }

    fn get_library_source(&self, source: Source) -> &'static str {
        match source {
            Source::Affine => AFFINE,
            Source::Unary => UNARY,
            Source::Binary => BINARY,
            Source::Ternary => TERNARY,
            Source::Indexing => INDEXING,
            Source::Cast => CAST,
            Source::Reduce => REDUCE,
            Source::Mfa => panic!("Invalid lib"),
        }
    }

    /// Load the given library from its [`source`].
    /// If it has been loaded previously, it is simply fetched from the cache.
    pub fn load_library(
        &self,
        device: &Device,
        source: Source,
    ) -> Result<Library, MetalKernelError> {
        let mut libraries = self.libraries.write()?;
        if let Some(lib) = libraries.get(&source) {
            Ok(lib.clone())
        } else {
            let lib = match source {
                Source::Mfa => {
                    let source_data = MFA;
                    device.new_library_with_data(source_data).map_err(|e| {
                        MetalKernelError::LoadLibraryError(format!(
                            "Candle metal requires macOS 13.0 or higher, cannot load mfa: {e}"
                        ))
                    })?
                }
                source => {
                    let source_content = self.get_library_source(source);
                    device
                        .new_library_with_source(source_content, &CompileOptions::new())
                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
                }
            };
            libraries.insert(source, lib.clone());
            Ok(lib)
        }
    }

    fn load_function(
        &self,
        device: &Device,
        source: Source,
        name: &'static str,
        constants: Option<FunctionConstantValues>,
    ) -> Result<Function, MetalKernelError> {
        let func = self
            .load_library(device, source)?
            .get_function(name, constants)
            .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
        Ok(func)
    }

    /// Load the given pipeline:
    /// loads the library from its source, then gets the function [`name`] from
    /// that library, specialized with the provided constant values.
    fn load_pipeline_with_constants(
        &self,
        device: &Device,
        source: Source,
        name: &'static str,
        constants: Option<ConstantValues>,
    ) -> Result<ComputePipelineState, MetalKernelError> {
        let mut pipelines = self.pipelines.write()?;
        let key = (name, constants);
        if let Some(pipeline) = pipelines.get(&key) {
            Ok(pipeline.clone())
        } else {
            let (name, constants) = key;
            let func = self.load_function(
                device,
                source,
                name,
                constants.as_ref().map(|c| c.function_constant_values()),
            )?;
            let pipeline = device
                .new_compute_pipeline_state_with_function(&func)
                .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
            pipelines.insert((name, constants), pipeline.clone());

            Ok(pipeline)
        }
    }

    /// Load the given pipeline:
    /// loads the library from its source, then gets the function [`name`] from
    /// that library (without function constants).
    pub fn load_pipeline(
        &self,
        device: &Device,
        source: Source,
        name: &'static str,
    ) -> Result<ComputePipelineState, MetalKernelError> {
        self.load_pipeline_with_constants(device, source, name, None)
    }
}
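
// Usage sketch (illustrative, not part of the crate's documented API): a `Kernels` cache is
// created once per device and shared across calls; pipelines are compiled lazily on first use
// and fetched from the cache afterwards.
//
//     let device = metal::Device::system_default().expect("no Metal device");
//     let kernels = Kernels::new(device.new_fence());
//     let pipeline = kernels.load_pipeline(&device, Source::Unary, "cos_f32")?;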

#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
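
// Call sketch (illustrative; `command_queue`, `input`, and `output` are assumed to exist, and
// buffer allocation/synchronization is elided): applying a contiguous unary op to `n` f32
// elements already resident in `input`.
//
//     let command_buffer = command_queue.new_command_buffer();
//     call_unary_contiguous(
//         &device,
//         command_buffer,
//         &kernels,
//         unary::contiguous::cos::FLOAT,
//         n,
//         &input,
//         &output,
//     )?;
//     command_buffer.commit();
//     command_buffer.wait_until_completed();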

#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: unary::strided::Kernel,
    shape: &[usize],
    input: &Buffer,
    strides: &[usize],
    offset: usize,
    output: &Buffer,
    output_offset: usize,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;

    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();
    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            strides,
            (input, offset),
            (output, output_offset)
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_binary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: binary::contiguous::Kernel,
    length: usize,
    left: &Buffer,
    right: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, left, right, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(left, metal::MTLResourceUsage::Read);
    encoder.use_resource(right, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_binary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: binary::strided::Kernel,
    shape: &[usize],
    left_input: &Buffer,
    left_strides: &[usize],
    left_offset: usize,
    right_input: &Buffer,
    right_strides: &[usize],
    right_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;

    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();

    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            left_strides,
            right_strides,
            (left_input, left_offset),
            (right_input, right_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
    encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_cast_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, (input, input_offset), output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_cast_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input_strides: &[usize],
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();

    set_params!(
        encoder,
        (
            length,
            shape.len(),
            shape,
            input_strides,
            (input, input_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    out_length: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let elements_to_sum = length / out_length;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (length, elements_to_sum, (input, input_offset), output)
    );

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        (elements_to_sum as u64 + 2 - 1) / 2,
    )
    .next_power_of_two();

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
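
// Dispatch sketch (illustrative): one threadgroup is launched per output element, and the
// threadgroup width is the next power of two of ceil(elements_to_sum / 2), capped at the
// pipeline maximum. E.g. with `length = 4096` and `out_length = 4`, each of the 4 threadgroups
// reduces 1024 elements with a width of next_power_of_two(512) = 512 threads, consistent with
// a tree reduction in which each thread first combines two inputs before the shared-memory
// passes.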

#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_length: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let length: usize = shape.iter().product();
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let elements_to_sum = length / out_length;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            shape.len(),
            shape,
            strides,
            elements_to_sum,
            (input, input_offset),
            output
        )
    );

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        elements_to_sum as u64,
    )
    .next_power_of_two();

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    elements_to_sum: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (length, elements_to_sum, (input, input_offset), output)
    );

    let out_length = length / elements_to_sum;

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        elements_to_sum as u64,
    )
    .next_power_of_two();

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_affine(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    output: &Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, add, input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_affine_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            add,
            (input, input_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_powf(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_powf_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            (input, input_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_elu(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_elu_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            (input, input_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    cond: &Buffer,
    (cond_stride, cond_offset): (&[usize], usize),
    left: &Buffer,
    (left_stride, left_offset): (&[usize], usize),
    right: &Buffer,
    (right_stride, right_offset): (&[usize], usize),
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    let size: usize = shape.iter().product();
    let rank = shape.len();

    set_params!(
        encoder,
        (
            size,
            rank,
            shape,
            cond_stride,
            left_stride,
            right_stride,
            (cond, cond_offset),
            (left, left_offset),
            (right, right_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);

    encoder.use_resource(cond, metal::MTLResourceUsage::Read);
    encoder.use_resource(left, metal::MTLResourceUsage::Read);
    encoder.use_resource(right, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_index_select(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: &Buffer,
    ids: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let src_dim_size = shape[dim];
    let dst_el = ids_size * left_size * right_size;

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;

    let encoder = command_buffer.new_compute_command_encoder();

    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            ids_size,
            input,
            ids,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
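
// Shape sketch (illustrative): for `shape = [2, 5, 3]`, `dim = 1`, and `ids_size = 4`,
// `left_size = 2` (product of the dims before `dim`), `right_size = 3` (product of the dims
// after), `src_dim_size = 5`, and the output holds `dst_el = 4 * 2 * 3 = 24` elements, i.e.
// the selected rows laid out as shape `[2, 4, 3]`.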

#[allow(clippy::too_many_arguments)]
pub fn call_gather(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let src_dim_size = shape[dim];
    let dst_el = ids_size * left_size * right_size;

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;

    let encoder = command_buffer.new_compute_command_encoder();

    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            ids_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;

    let encoder = command_buffer.new_compute_command_encoder();

    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    ids_shape: &[usize],
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];
    let ids_dim_size = ids_shape[0];

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let encoder = command_buffer.new_compute_command_encoder();

    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            ids_dim_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[derive(Debug, PartialEq)]
pub enum Value {
    USize(usize),
    Bool(bool),
    F32(f32),
    U16(u16),
}

impl std::hash::Hash for Value {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            Value::F32(v) => v.to_bits().hash(state),
            Value::USize(v) => v.hash(state),
            Value::U16(v) => v.hash(state),
            Value::Bool(v) => v.hash(state),
        }
    }
}

impl Value {
    fn data_type(&self) -> MTLDataType {
        match self {
            Value::USize(_) => MTLDataType::UInt,
            Value::F32(_) => MTLDataType::Float,
            Value::U16(_) => MTLDataType::UShort,
            Value::Bool(_) => MTLDataType::Bool,
        }
    }
}

/// Not strictly true (`f32` is not `Eq` because of NaN), but good enough for our purposes.
impl Eq for Value {}

#[derive(Debug, Eq, PartialEq, Hash)]
struct ConstantValues(Vec<(usize, Value)>);

impl ConstantValues {
    pub fn new(values: Vec<(usize, Value)>) -> Self {
        Self(values)
    }

    fn function_constant_values(&self) -> FunctionConstantValues {
        let f = FunctionConstantValues::new();
        for (index, value) in &self.0 {
            let ty = value.data_type();
            match value {
                Value::USize(v) => {
                    f.set_constant_value_at_index(
                        v as *const usize as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::F32(v) => {
                    f.set_constant_value_at_index(
                        v as *const f32 as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::U16(v) => {
                    f.set_constant_value_at_index(
                        v as *const u16 as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::Bool(v) => {
                    f.set_constant_value_at_index(
                        v as *const bool as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
            }
        }
        f
    }
}

#[allow(clippy::too_many_arguments)]
pub fn call_gemm(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs_stride: &[usize],
    lhs_offset: usize,
    lhs_buffer: &Buffer,
    rhs_stride: &[usize],
    rhs_offset: usize,
    rhs_buffer: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    assert!(rhs_stride.len() >= 2);
    assert!(lhs_stride.len() >= 2);
    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
    let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
        false
    } else if lhs_m1 == m && lhs_m2 == 1 {
        true
    } else {
        return Err(MetalKernelError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        });
    };
    let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
        false
    } else if rhs_m1 == k && rhs_m2 == 1 {
        true
    } else {
        return Err(MetalKernelError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        });
    };
    let d_trans = false;
    let alpha = 1.0f32;
    let beta = 0.0f32;
    let batched = b > 1;
    let fused_activation = false;
    let fused_bias = false;
    let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
        let m_simd = 16;
        let n_simd = 8;
        let k_simd = 64;
        let m_splits = 1;
        let n_splits = 1;
        (m_simd, n_simd, k_simd, m_splits, n_splits)
    } else {
        let m_simd = 40;
        let n_simd = 40;
        let k_simd = 8;
        let m_splits = 1;
        let n_splits = 1;
        (m_simd, n_simd, k_simd, m_splits, n_splits)
    };
    let constants = Some(ConstantValues::new(vec![
        (0, Value::USize(m)),
        (1, Value::USize(n)),
        (2, Value::USize(k)),
        (10, Value::Bool(a_trans)),
        (11, Value::Bool(b_trans)),
        (13, Value::Bool(d_trans)),
        (20, Value::F32(alpha)),
        (21, Value::F32(beta)),
        (100, Value::Bool(batched)),
        (101, Value::Bool(fused_activation)),
        // Garbage
        (102, Value::Bool(false)),
        (103, Value::Bool(false)),
        (113, Value::Bool(false)),
        (50_000, Value::Bool(false)),
        // End garbage
        (200, Value::U16(m_simd)),
        (201, Value::U16(n_simd)),
        (202, Value::U16(k_simd)),
        (210, Value::U16(m_splits)),
        (211, Value::U16(n_splits)),
        (50_001, Value::Bool(fused_bias)),
    ]));
    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
    let m_group = m_simd * m_splits;
    let n_group = n_simd * n_splits;

    let a_block_length = m_group * k_simd;
    let b_block_length = k_simd * n_group;

    let mut block_elements = a_block_length + b_block_length;
    if (m % 8 != 0) && (n % 8 != 0) {
        let c_block_length = m_group * n_group;
        block_elements = std::cmp::max(c_block_length, block_elements)
    }
    if fused_bias {
        if d_trans {
            block_elements = std::cmp::max(block_elements, m_group);
        } else {
            block_elements = std::cmp::max(block_elements, n_group);
        }
    }
    let bytes = match name {
        "sgemm" => 4,
        "hgemm" => 2,
        other => {
            return Err(MetalKernelError::LoadLibraryError(format!(
                "{other} is not a valid kernel for gemm"
            )));
        }
    };
    let block_bytes = block_elements * bytes;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    encoder.set_threadgroup_memory_length(0, block_bytes.into());
    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
    encoder.set_buffer(2, Some(output), 0);
    // TODO Tensor D

    let grid_z = b;
    if batched {
        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
        let byte_stride_c = m * n * bytes as usize;
        // TODO byte_stride_d
        let byte_stride_d = 0;

        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
        for i in 0..b {
            buffer.push((i * byte_stride_a) as u64);
            buffer.push((i * byte_stride_b) as u64);
            buffer.push((i * byte_stride_c) as u64);
            buffer.push((i * byte_stride_d) as u64);
        }
        encoder.set_bytes(
            10,
            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
            buffer.as_ptr() as *const NSUInteger as *const c_void,
        );
    }

    let grid_size = MTLSize {
        width: divide(n, n_group.into()),
        height: divide(m, m_group.into()),
        depth: grid_z as NSUInteger,
    };
    let group_size = MTLSize {
        width: 32 * (m_splits as u64) * (n_splits as u64),
        height: 1,
        depth: 1,
    };
    encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(grid_size, group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();

    Ok(())
}
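
// Call sketch (illustrative; `lhs`, `rhs`, and `output` are assumed buffers): a single
// non-batched f32 matmul C = A * B with row-major, contiguous operands, where A is (m, k) and
// B is (k, n), so the trailing strides are [k, 1] and [n, 1] and neither operand is detected
// as transposed.
//
//     call_gemm(
//         &device,
//         command_buffer,
//         &kernels,
//         "sgemm",
//         (1, m, n, k),
//         &[m * k, k, 1],
//         0,
//         &lhs,
//         &[k * n, n, 1],
//         0,
//         &rhs,
//         &output,
//     )?;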

fn divide(m: usize, b: usize) -> NSUInteger {
    ((m + b - 1) / b) as NSUInteger
}

#[cfg(test)]
mod tests;