Compare commits


3 Commits

SHA1        Message  Date
1f23cea90c  MFA      2023-12-13 16:09:20 +01:00
ce33d6ad2a  Tmp.     2023-12-11 11:10:48 +01:00
3d0ade406a  Tmp.     2023-12-11 09:38:25 +01:00
6 changed files with 568 additions and 459 deletions

View File

@@ -795,80 +795,38 @@ impl BackendStorage for MetalStorage {
         rhs_l: &Layout,
     ) -> Result<Self> {
         // Create descriptors
-        let (type_id, size, name) = match self.dtype {
-            DType::F32 => (
-                metal::mps::MPS_FLOATBIT_ENCODING | 32,
-                core::mem::size_of::<f32>() as NSUInteger,
-                "sgemm",
-            ),
-            DType::F16 => (
-                metal::mps::MPS_FLOATBIT_ENCODING | 16,
-                core::mem::size_of::<f16>() as NSUInteger,
-                "hgemm",
-            ),
-            dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
-        };
-        let lhs_stride = lhs_l.stride();
-        let rhs_stride = rhs_l.stride();
-        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
-        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
-        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
-        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
-        // The a tensor has dims batching, k, n (rhs)
-        let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
-            false
-        } else if lhs_m1 == m && lhs_m2 == 1 {
-            true
-        } else {
-            Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?
-        };
-        let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
-            false
-        } else if rhs_m1 == k && rhs_m2 == 1 {
-            true
-        } else {
-            Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?
-        };
-        let result_buffer = self.device.new_buffer(b * m * n, self.dtype);
+        let buffer = self.device.new_buffer(b * m * n, self.dtype);
+        let name = match self.dtype {
+            DType::F32 => "sgemm",
+            DType::F16 => "hgemm",
+            dtype => {
+                return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
+            }
+        };
         let command_buffer = self.device.command_buffer();
-        command_buffer.set_label("mfa gemm");
-        candle_metal_kernels::call_mfa_gemm(
+        command_buffer.set_label("matmul");
+        candle_metal_kernels::call_gemm(
             &self.device.device,
             &command_buffer,
             &self.device.kernels,
             name,
-            &self.buffer,
-            lhs_l.shape().dims(),
-            &rhs.buffer,
-            rhs_l.shape().dims(),
-            &result_buffer,
             (b, m, n, k),
-            transpose_left,
-            transpose_right,
+            &lhs_l.stride(),
+            lhs_l.start_offset(),
+            &self.buffer,
+            &rhs_l.stride(),
+            rhs_l.start_offset(),
+            &rhs.buffer,
+            &buffer,
         )
         .map_err(MetalError::from)?;
-        // Create kernel
         drop(command_buffer);
         self.device.commit();
-        Ok(Self::new(
-            self.buffer.clone(),
-            self.device.clone(),
-            self.dtype(),
-        ))
+        Ok(Self::new(buffer, self.device.clone(), self.dtype()))
     }
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
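Note: the transpose detection that this hunk deletes from the backend (the transpose_left / transpose_right checks) now lives inside call_gemm in the kernel crate further down, as a_trans / b_trans derived from the strides. A minimal standalone sketch of that stride check, with a hypothetical helper name that is not part of the candle API:

// Sketch only: mirrors the lhs_m1 / lhs_m2 logic from the diff.
// For a row-major (rows, cols) matrix the innermost stride is 1 and the next one is cols;
// if the two are swapped, the data is effectively a transposed view.
fn is_transposed(stride: &[usize], rows: usize, cols: usize) -> Option<bool> {
    let inner = stride[stride.len() - 1];
    let outer = stride[stride.len() - 2];
    if inner == 1 && outer == cols {
        Some(false) // contiguous row-major
    } else if inner == rows && outer == 1 {
        Some(true) // transposed view of a contiguous (cols, rows) matrix
    } else {
        None // non-contiguous: MatMulNonContiguous in the old code, todo!() in the new call_gemm
    }
}

fn main() {
    // (rows, cols) = (2, 3): contiguous strides are [3, 1], a transposed view has strides [1, 2].
    assert_eq!(is_transposed(&[3, 1], 2, 3), Some(false));
    assert_eq!(is_transposed(&[1, 2], 2, 3), Some(true));
}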

View File

@@ -11,7 +11,6 @@ license = "MIT OR Apache-2.0"
 [dependencies]
 metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
-metal-flash-attention = { path = "../../../metal-flash-attention" }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"

View File

@@ -1,12 +1,9 @@
 use metal::{
     Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
-    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceUsage, MTLSize,
-    NSUInteger,
+    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
 };
-use std::collections::{BTreeMap, HashMap};
+use std::collections::HashMap;
 use std::ffi::c_void;
-use std::hash::Hash;
-use std::io::{stdout, Write};
 use std::sync::RwLock;
 const AFFINE: &str = include_str!("affine.metal");
@@ -16,7 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
-const MFA_LIB: &[u8] = include_bytes!("mfa.metallib");
+const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64; let size = length as u64;
@@ -109,7 +106,7 @@ pub enum Source {
     Ternary,
     Cast,
     Reduce,
-    MetalFlashAttention,
+    Mfa,
 }
 macro_rules! ops{
@@ -184,9 +181,88 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
     }
 }
-type KernelMap<T> = HashMap<KernelKey, T>;
+#[derive(Debug, PartialEq)]
pub enum Value {
USize(usize),
Bool(bool),
F32(f32),
U16(u16),
}
impl std::hash::Hash for Value {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
Value::F32(v) => v.to_bits().hash(state),
Value::USize(v) => v.hash(state),
Value::U16(v) => v.hash(state),
Value::Bool(v) => v.hash(state),
}
}
}
impl Value {
fn data_type(&self) -> MTLDataType {
match self {
Value::USize(_) => MTLDataType::UInt,
Value::F32(_) => MTLDataType::Float,
Value::U16(_) => MTLDataType::UShort,
Value::Bool(_) => MTLDataType::Bool,
}
}
}
/// Not true, good enough for our purposes.
impl Eq for Value {}
#[derive(Debug, Eq, PartialEq, Hash)]
struct ConstantValues(Vec<(usize, Value)>);
impl ConstantValues {
pub fn new(values: Vec<(usize, Value)>) -> Self {
Self(values)
}
fn function_constant_values(&self) -> FunctionConstantValues {
let f = FunctionConstantValues::new();
for (index, value) in &self.0 {
let ty = value.data_type();
match value {
Value::USize(v) => {
f.set_constant_value_at_index(
v as *const usize as *const c_void,
ty,
*index as u64,
);
}
Value::F32(v) => {
f.set_constant_value_at_index(
v as *const f32 as *const c_void,
ty,
*index as u64,
);
}
Value::U16(v) => {
f.set_constant_value_at_index(
v as *const u16 as *const c_void,
ty,
*index as u64,
);
}
Value::Bool(v) => {
f.set_constant_value_at_index(
v as *const bool as *const c_void,
ty,
*index as u64,
);
}
}
}
f
}
}
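Aside: Value gets a manual Hash impl (hashing the f32 bits) and the Eq marker because the whole ConstantValues list becomes part of the pipeline-cache key used by load_pipeline_with_constants below. A rough sketch of that keying, with a placeholder u32 standing in for a real ComputePipelineState:

fn pipeline_cache_sketch() {
    use std::collections::HashMap;
    // Hypothetical stand-in cache: the real Pipelines map stores ComputePipelineState values.
    let mut pipelines: HashMap<(&'static str, Option<ConstantValues>), u32> = HashMap::new();
    pipelines.insert(
        ("sgemm", Some(ConstantValues::new(vec![(0, Value::USize(2))]))),
        0,
    );
    // Same kernel name with the same constants hits the cache; different constants
    // produce a separate specialization of the metallib function.
    assert!(pipelines
        .contains_key(&("sgemm", Some(ConstantValues::new(vec![(0, Value::USize(2))])))));
}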
 type Libraries = HashMap<Source, Library>;
-type Pipelines = KernelMap<ComputePipelineState>;
+type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
 #[derive(Debug, Default)]
 pub struct Kernels {
@@ -194,22 +270,6 @@ pub struct Kernels {
     pipelines: RwLock<Pipelines>,
 }
enum LibraryDefinition {
Source(&'static str),
Data(&'static [u8]),
}
impl From<&'static str> for LibraryDefinition {
fn from(s: &'static str) -> Self {
Self::Source(s)
}
}
impl From<&'static [u8]> for LibraryDefinition {
fn from(s: &'static [u8]) -> Self {
Self::Data(s)
}
}
 impl Kernels {
     pub fn new() -> Self {
         let libraries = RwLock::new(Libraries::new());
@@ -220,16 +280,16 @@ impl Kernels {
         }
     }
-    fn get_library_source(&self, source: Source) -> LibraryDefinition {
+    fn get_library_source(&self, source: Source) -> &'static str {
         match source {
-            Source::Affine => AFFINE.into(),
-            Source::Unary => UNARY.into(),
-            Source::Binary => BINARY.into(),
-            Source::Ternary => TERNARY.into(),
-            Source::Indexing => INDEXING.into(),
-            Source::Cast => CAST.into(),
-            Source::Reduce => REDUCE.into(),
-            Source::MetalFlashAttention => MFA_LIB.into(),
+            Source::Affine => AFFINE,
+            Source::Unary => UNARY,
+            Source::Binary => BINARY,
+            Source::Ternary => TERNARY,
+            Source::Indexing => INDEXING,
+            Source::Cast => CAST,
+            Source::Reduce => REDUCE,
+            Source::Mfa => unimplemented!("Mfa is not a source"),
         }
     }
@@ -242,15 +302,20 @@ impl Kernels {
         if let Some(lib) = libraries.get(&source) {
             Ok(lib.clone())
         } else {
-            let lib = match self.get_library_source(source) {
-                LibraryDefinition::Source(source_content) => device
-                    .new_library_with_source(source_content, &CompileOptions::new())
-                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
-                LibraryDefinition::Data(data) => device
-                    .new_library_with_data(data)
-                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
+            let lib = match source {
+                Source::Mfa => {
+                    let source_data = MFA;
+                    device
+                        .new_library_with_data(source_data)
+                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                }
+                source => {
+                    let source_content = self.get_library_source(source);
+                    device
+                        .new_library_with_source(source_content, &CompileOptions::new())
+                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                }
             };
             libraries.insert(source, lib.clone());
             Ok(lib)
         }
@@ -260,187 +325,51 @@ impl Kernels {
         &self,
         device: &Device,
         source: Source,
-        key: KernelKey,
+        name: &'static str,
+        constants: Option<FunctionConstantValues>,
     ) -> Result<Function, MetalKernelError> {
         let func = self
             .load_library(device, source)?
-            .get_function(
-                key.name,
-                key.constants.map(|c| c.create_function_constant_values()),
-            )
+            .get_function(name, constants)
             .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
         Ok(func)
     }
-    pub fn load_pipeline<T: Into<KernelKey>>(
+    fn load_pipeline_with_constants(
         &self,
         device: &Device,
         source: Source,
-        key: T,
+        name: &'static str,
+        constants: Option<ConstantValues>,
     ) -> Result<ComputePipelineState, MetalKernelError> {
-        let key: KernelKey = key.into();
         let mut pipelines = self.pipelines.write()?;
+        let key = (name, constants);
         if let Some(pipeline) = pipelines.get(&key) {
             Ok(pipeline.clone())
         } else {
-            let func = self.load_function(device, source, key.clone())?;
+            let (name, constants) = key;
+            let func = self.load_function(
+                device,
+                source,
+                name,
+                constants.as_ref().map(|c| c.function_constant_values()),
+            )?;
             let pipeline = device
                 .new_compute_pipeline_state_with_function(&func)
                 .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
-            pipelines.insert(key, pipeline.clone());
+            pipelines.insert((name, constants), pipeline.clone());
             Ok(pipeline)
         }
     }
+    pub fn load_pipeline(
+        &self,
+        device: &Device,
+        source: Source,
+        name: &'static str,
+    ) -> Result<ComputePipelineState, MetalKernelError> {
+        self.load_pipeline_with_constants(device, source, name, None)
+    }
+}
-}
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-struct KernelKey {
-    name: &'static str,
-    constants: Option<ConstantValues>,
-}
-impl KernelKey {
-    fn new(name: &'static str) -> Self {
-        Self {
-            name,
-            constants: None,
-        }
-    }
-    fn with_constants(mut self, constants: ConstantValues) -> Self {
-        self.constants = Some(constants);
-        self
-    }
-}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ConstantValueId {
Index(NSUInteger),
Name(&'static str),
}
trait MetalDType {
const MTL_DATA_TYPE: MTLDataType;
}
macro_rules! metal_dtype {
($ty:ty, $mtl_data_type:ident) => {
impl MetalDType for $ty {
const MTL_DATA_TYPE: MTLDataType = MTLDataType::$mtl_data_type;
}
};
}
metal_dtype!(f32, Float);
metal_dtype!(u32, UInt);
metal_dtype!(u16, UShort);
metal_dtype!(bool, Bool);
#[derive(Debug, Clone, PartialEq)]
enum ConstantValueType {
Float(f32),
Uint(u32),
UShort(u16),
Bool(bool),
}
impl Hash for ConstantValueType {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
use ConstantValueType::*;
match self {
Float(v) => v.to_bits().hash(state),
Uint(v) => v.hash(state),
UShort(v) => v.hash(state),
Bool(v) => v.hash(state),
}
}
}
impl Eq for ConstantValueType {}
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConstantValues(BTreeMap<ConstantValueId, ConstantValueType>);
macro_rules! add_indexed_constant {
($fcv:expr, $value:expr, $ty:ty, $idx:expr) => {
$fcv.set_constant_value_at_index(
$value as *const $ty as *const c_void,
<$ty>::MTL_DATA_TYPE,
$idx,
)
};
}
macro_rules! add_named_constant {
($fcv:expr, $value:expr, $ty:ty, $name:expr) => {
$fcv.set_constant_value_with_name(
$value as *const $ty as *const c_void,
<$ty>::MTL_DATA_TYPE,
$name,
)
};
}
impl Hash for ConstantValues {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
for (id, value) in &self.0 {
id.hash(state);
value.hash(state);
}
}
}
impl ConstantValues {
fn new() -> Self {
Self(BTreeMap::new())
}
fn set(mut self, id: impl Into<ConstantValueId>, value: impl Into<ConstantValueType>) -> Self {
self.0.insert(id.into(), value.into());
self
}
fn create_function_constant_values(&self) -> FunctionConstantValues {
use ConstantValueId::*;
use ConstantValueType::*;
let mut function_values = FunctionConstantValues::new();
for (id, value) in &self.0 {
match (&id, &value) {
(Index(index), Float(value)) => {
add_indexed_constant!(function_values, value, f32, *index);
}
(Index(index), Uint(value)) => {
add_indexed_constant!(function_values, value, u32, *index);
}
(Index(index), UShort(value)) => {
add_indexed_constant!(function_values, value, u16, *index);
}
(Index(index), Bool(value)) => {
add_indexed_constant!(function_values, value, bool, *index);
}
(Name(name), Float(value)) => {
add_named_constant!(function_values, value, f32, name);
}
(Name(name), Uint(value)) => {
add_named_constant!(function_values, value, u32, name);
}
(Name(name), UShort(value)) => {
add_named_constant!(function_values, value, u16, name);
}
(Name(name), Bool(value)) => {
add_named_constant!(function_values, value, bool, name);
}
}
}
function_values
}
}
-impl From<&'static str> for KernelKey {
-    fn from(name: &'static str) -> Self {
-        Self {
-            name,
-            constants: None,
-        }
-    }
-}
@@ -880,230 +809,169 @@ pub fn call_index_select(
     Ok(())
 }
impl From<NSUInteger> for ConstantValueId {
fn from(idx: NSUInteger) -> Self {
Self::Index(idx)
}
}
impl From<usize> for ConstantValueId {
fn from(idx: usize) -> Self {
ConstantValueId::from(idx as NSUInteger)
}
}
impl From<i32> for ConstantValueId {
fn from(idx: i32) -> Self {
ConstantValueId::from(idx as NSUInteger)
}
}
impl From<&'static str> for ConstantValueId {
fn from(name: &'static str) -> Self {
Self::Name(name)
}
}
macro_rules! to_constant_value {
($ty:ty, $constant_value_type:ident) => {
to_constant_value!($ty, $ty, $constant_value_type);
};
($ty:ty, $via:ty, $constant_value_type:ident) => {
impl From<$ty> for ConstantValueType {
fn from(v: $ty) -> Self {
Self::$constant_value_type(v as $via)
}
}
};
}
to_constant_value!(f32, Float);
to_constant_value!(u32, Uint);
to_constant_value!(usize, u32, Uint);
to_constant_value!(u16, UShort);
to_constant_value!(bool, Bool);
struct MFAGemmConfig {
m: usize,
k: usize,
n: usize,
transpose_left: bool,
transpose_right: bool,
batched: bool,
m_simd: u16,
n_simd: u16,
k_simd: u16,
m_splits: u16,
n_splits: u16,
m_group: u16,
n_group: u16,
}
impl From<MFAGemmConfig> for ConstantValues {
fn from(conf: MFAGemmConfig) -> Self {
ConstantValues::new()
.set(0, conf.m)
.set(1, conf.k)
.set(2, conf.n)
.set(10, conf.transpose_left)
.set(11, conf.transpose_right)
.set(12, false)
.set(20, 1.0)
.set(21, 0.0)
.set(100, conf.batched)
.set(101, false)
.set(50001, false)
.set(200, conf.m_simd)
.set(201, conf.n_simd)
.set(202, conf.k_simd)
.set(210, conf.m_splits)
.set(211, conf.n_splits)
// garbage
.set(102, false)
.set(103, false)
.set(113, false)
.set(50000, false)
}
}
 #[allow(clippy::too_many_arguments)]
-pub fn call_mfa_gemm(
+pub fn call_gemm(
     device: &Device,
     command_buffer: &CommandBufferRef,
     kernels: &Kernels,
     name: &'static str,
-    lhs: &Buffer,
-    lhs_dims: &[usize],
-    rhs: &Buffer,
-    rhs_dims: &[usize],
-    output: &Buffer,
     (b, m, n, k): (usize, usize, usize, usize),
-    transpose_left: bool,
-    transpose_right: bool,
+    lhs_stride: &[usize],
+    lhs_offset: usize,
+    lhs_buffer: &Buffer,
+    rhs_stride: &[usize],
+    rhs_offset: usize,
+    rhs_buffer: &Buffer,
+    output: &Buffer,
 ) -> Result<(), MetalKernelError> {
+    assert!(rhs_stride.len() >= 2);
+    assert!(lhs_stride.len() >= 2);
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
false
} else if lhs_m1 == m && lhs_m2 == 1 {
true
} else {
todo!();
// Err(MetalError::MatMulNonContiguous {
// lhs_stride: lhs_stride.to_vec(),
// rhs_stride: rhs_stride.to_vec(),
// mnk: (m, n, k),
// })?
};
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
false
} else if rhs_m1 == k && rhs_m2 == 1 {
true
} else {
todo!();
// Err(MetalError::MatMulNonContiguous {
// lhs_stride: lhs_stride.to_vec(),
// rhs_stride: rhs_stride.to_vec(),
// mnk: (m, n, k),
// })?
};
let d_trans = false;
let alpha = 1.0f32;
let beta = 0.0f32;
     let batched = b > 1;
-    let mut c_elements = m * n;
-    if batched {
-        c_elements *= 2;
-    }
-    let is_half = name == "hgemm";
-    let is_float = name == "sgemm";
-    let mut m_group = 32;
-    let mut n_group = 32;
-    let mut k_simd = 32;
-    if c_elements > 10 ^ 6 {
-        m_group = 48;
-        n_group = 48;
-    }
-    // If K_simd is perfectly equal to matrix K, the compiler can elide a large
-    // amount of logic in the kernel.
-    if k >= 33 && k <= 40 {
-        k_simd = 40;
-    } else if is_half && k >= 73 && k >= 80 {
-        k_simd = 80;
-    } else if c_elements > 10 ^ 6 {
-        if k <= 16 {
-            k_simd = 16;
-        } else if k <= 24 {
-            k_simd = 24;
-        } else if k <= 32 {
-            k_simd = 32;
-        } else if k <= 48 {
-            k_simd = 24;
-        } else if k <= 64 {
-            k_simd = 32;
-        } else if is_float {
-            k_simd = 24;
-        }
-    }
+    let fused_activation = false;
+    let fused_bias = false;
+    let m_simd = 16;
+    let n_simd = 16;
+    let k_simd = 16;
     let m_splits = 2;
     let n_splits = 2;
-    let m_simd = m_group / m_splits;
-    let n_simd = n_group / n_splits;
+    let constants = Some(ConstantValues::new(vec![
+        (0, Value::USize(m)),
+        (1, Value::USize(n)),
+        (2, Value::USize(k)),
+        (10, Value::Bool(a_trans)),
+        (11, Value::Bool(b_trans)),
+        (13, Value::Bool(d_trans)),
+        (20, Value::F32(alpha)),
+        (21, Value::F32(beta)),
+        (100, Value::Bool(batched)),
+        (101, Value::Bool(fused_activation)),
+        // Garbage
+        (102, Value::Bool(false)),
+        (103, Value::Bool(false)),
+        (113, Value::Bool(false)),
+        (50_000, Value::Bool(false)),
+        // End garbage
+        (200, Value::U16(m_simd)),
+        (201, Value::U16(n_simd)),
+        (202, Value::U16(k_simd)),
+        (210, Value::U16(m_splits)),
+        (211, Value::U16(n_splits)),
+        (50_001, Value::Bool(fused_bias)),
+    ]));
+    // println!("Constants {constants:?}");
+    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
+    let m_group = m_simd * m_splits;
+    let n_group = n_simd * n_splits;
-    let config = MFAGemmConfig {
-        m,
-        k,
-        n,
-        transpose_left,
-        transpose_right,
-        batched,
-        m_simd,
-        n_simd,
-        k_simd,
-        m_splits,
-        n_splits,
-        m_group,
-        n_group,
-    };
-    let pipeline = kernels.load_pipeline(
-        device,
-        Source::MetalFlashAttention,
-        KernelKey::new(name).with_constants(config.into()),
-    )?;
-    let block_type_size = if is_half { 2 } else { 4 };
-    let a_block_bytes = m_group * k_simd * block_type_size;
-    let b_block_bytes = k_simd * n_group * block_type_size;
-    let c_block_bytes = m_group * n_group * block_type_size;
-    let mut thread_group_memory_length = a_block_bytes + b_block_bytes;
-    if m % 8 > 0 && n % 8 > 0 {
-        thread_group_memory_length = core::cmp::max(thread_group_memory_length, c_block_bytes);
-    }
+    let a_block_length = m_group * k_simd;
+    let b_block_length = k_simd * n_group;
+    let mut block_elements = a_block_length + b_block_length;
+    if (m % 8 != 0) && (n % 8 != 0) {
+        let c_block_length = m_group * n_group;
+        block_elements = std::cmp::max(c_block_length, block_elements)
+    }
if fused_bias {
if d_trans {
block_elements = std::cmp::max(block_elements, m_group);
} else {
block_elements = std::cmp::max(block_elements, n_group);
}
}
// TODO adapt for f16
let bytes = match name {
"sgemm" => 4,
"hgemm" => 2,
other => {
return Err(MetalKernelError::LoadLibraryError(format!(
"{other} is not a valid kernel for gemm"
)));
}
};
let block_bytes = block_elements * bytes;
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
-    encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
-    encoder.use_resources(&[&lhs, &rhs], MTLResourceUsage::Read);
-    encoder.use_resource(&output, MTLResourceUsage::Write);
-    encoder.set_buffers(0, &[Some(lhs), Some(rhs), Some(output)], &[0; 3]);
-    let ceil_divide = |a, b| (a + b - 1) / b;
-    let mut grid_z = 1;
+    // println!("Threadgroup {block_bytes}");
+    encoder.set_threadgroup_memory_length(0, block_bytes.into());
+    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
+    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
+    encoder.set_buffer(2, Some(output), 0);
+    // TODO Tensor D
+    let grid_z = b;
     if batched {
-        grid_z = lhs_dims[..lhs_dims.len() - 2].iter().product();
-        let byte_stride = |shape: &[usize]| -> u64 {
-            let rank = shape.len();
-            let mut output = core::mem::size_of::<f32>() * shape[rank - 2] * shape[rank - 1];
-            if shape[..shape.len() - 2].iter().product::<usize>() == 1 {
-                output = 0;
-            }
-            output as u64
-        };
-        let byte_stride_a = byte_stride(lhs_dims);
-        let byte_stride_b = byte_stride(rhs_dims);
-        let byte_stride_c = byte_stride(&[m, n]);
-        type BatchConfig = (u64, u64, u64, u64);
-        let mut batching_conf: Vec<BatchConfig> = vec![];
-        for i in 0..grid_z {
-            batching_conf.push((
-                i as u64 * byte_stride_a,
-                i as u64 * byte_stride_b,
-                i as u64 * byte_stride_c,
-                0u64,
-            ));
-        }
-        set_param(encoder, 10, batching_conf.as_slice());
+        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
+        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
+        let byte_stride_c = m * n * bytes as usize;
+        // TODO byte_stride_d
+        let byte_stride_d = 0;
+        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+        for i in 0..b {
+            buffer.push((i * byte_stride_a) as u64);
+            buffer.push((i * byte_stride_b) as u64);
+            buffer.push((i * byte_stride_c) as u64);
+            buffer.push((i * byte_stride_d) as u64);
+        }
+        encoder.set_bytes(
+            10,
+            buffer.len() as NSUInteger * core::mem::size_of::<u64>(),
+            buffer.as_ptr() as *const NSUInteger as *const c_void,
+        );
     }
-    let grid_size = MTLSize::new(
-        ceil_divide(n as NSUInteger, n_group as NSUInteger),
-        ceil_divide(m as NSUInteger, m_group as NSUInteger),
-        grid_z as NSUInteger,
-    );
-    let group_size = MTLSize::new((32 * m_splits * n_splits) as NSUInteger, 1, 1);
+    let grid_size = MTLSize {
+        width: divide(n, n_group.into()),
+        height: divide(m, m_group.into()),
+        depth: grid_z as NSUInteger,
+    };
+    let group_size = MTLSize {
+        width: 32 * (m_splits as u64) * (n_splits as u64),
+        height: 1,
+        depth: 1,
+    };
+    // println!("grid size {grid_size:?} group size {group_size:?}");
     encoder.dispatch_thread_groups(grid_size, group_size);
     encoder.end_encoding();
     Ok(())
 }
+fn divide(m: usize, b: usize) -> NSUInteger {
+    ((m + b - 1) / b) as NSUInteger
+}
 #[cfg(test)]
 mod tests;

View File

@@ -0,0 +1,211 @@
import Metal
import MetalPerformanceShadersGraph
let type = MTLDataType.float;
let dataType = type;
var B = 2;
var M = 2;
var N = 4;
var K = 3;
var A_trans = false;
var B_trans = false;
var D_trans = false;
var alpha = Float(1.0);
var beta = Float(0.0);
var batched = B > 1;
var fused_activation = false;
var fused_bias = false;
let constants = MTLFunctionConstantValues()
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&A_trans, type: .bool, index: 10)
constants.setConstantValue(&B_trans, type: .bool, index: 11)
constants.setConstantValue(&D_trans, type: .bool, index: 13)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
constants.setConstantValue(&batched, type: .bool, index: 100)
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
var M_simd = UInt16(16)
var N_simd = UInt16(16)
var K_simd = UInt16(32)
var M_splits = UInt16(2)
var N_splits = UInt16(2)
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)
let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
// Satisfy Metal API validation.
#if DEBUG
do {
var garbage: SIMD4<UInt64> = .zero
constants.setConstantValue(&garbage, type: .bool, index: 102)
constants.setConstantValue(&garbage, type: .bool, index: 103)
constants.setConstantValue(&garbage, type: .bool, index: 113)
constants.setConstantValue(&garbage, type: .bool, index: 50000)
}
#endif
print(constants)
let device = MTLCopyAllDevices().first!
device.shouldMaximizeConcurrentCompilation = true
var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
libraryURL.append(component: "src")
libraryURL.append(component: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: libraryURL)
var name: String
switch dataType {
case .half: name = "hgemm"
case .float: name = "sgemm"
default: fatalError()
}
let function = try! library.makeFunction(
name: name, constantValues: constants)
let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group
var blockElements = A_block_length + B_block_length;
if (M % 8 != 0) && (N % 8 != 0) {
let C_block_length = M_group * N_group;
blockElements = max(C_block_length, blockElements)
}
if fused_bias {
if D_trans {
blockElements = max(blockElements, M_group)
} else {
blockElements = max(blockElements, N_group)
}
}
// let blockBytes = blockElements * UInt16(dataType.size)
let elementSize = 4
let blockBytes = blockElements * UInt16(elementSize)
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
var gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_group),
height: ceilDivide(target: M, granularity: M_group),
depth: 1)
let groupSize = MTLSize(
width: Int(32 * M_splits * N_splits),
height: 1,
depth: 1)
let commandQueue = device.makeCommandQueue()!
let commandBuffer = commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
let pipeline = try device.makeComputePipelineState(function: function)
let threadgroupMemoryLength = blockBytes;
print(threadgroupMemoryLength)
encoder.setComputePipelineState(pipeline)
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
let rowsA = M;
let columnsA = K;
let rowsB = K;
let columnsB = N;
let rowsC = M;
let columnsC = N;
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
for i in 0..<arrayA.count {
arrayA[i] = Float(i)
}
for i in 0..<arrayB.count {
arrayB[i] = Float(i)
}
let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])
let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])
let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])
print(arrayA)
print(arrayB)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
var gridZ: Int = B
if batched{
func byteStride(shape: [Int]) -> Int {
let rank = shape.count
var output = elementSize * shape[rank - 2] * shape[rank - 1]
if shape.dropLast(2).reduce(1, *) == 1 {
output = 0
}
return output
}
let byteStrideA = M*K*elementSize
let byteStrideB = N*K*elementSize
let byteStrideC = M*N*elementSize
let byteStrideD = 0
// if let shapeD = tensors.d?.shape {
// let rank = shapeD.count
// byteStrideD = elementSize * shapeD[rank - 1]
// if shapeD.dropLast(1).reduce(1, *) == 1 {
// byteStrideD = 0
// }
// }
withUnsafeTemporaryAllocation(
of: SIMD4<UInt64>.self, capacity: gridZ
) { buffer in
for i in 0..<buffer.count {
buffer[i] = SIMD4(
UInt64(truncatingIfNeeded: i * byteStrideA),
UInt64(truncatingIfNeeded: i * byteStrideB),
UInt64(truncatingIfNeeded: i * byteStrideC),
UInt64(truncatingIfNeeded: i * byteStrideD))
}
let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
print("BATCHED")
print(buffer)
}
}
gridSize.depth = gridZ
print(gridSize, groupSize)
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize
)
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
var contents = bufferC!.contents();
var count = B * rowsA * columnsB;
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print(Array(bufferedPointer))
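For a sanity check on the threadgroup memory size this script prints: with M_simd = N_simd = 16, K_simd = 32 and 2x2 splits, M_group = N_group = 32, so the A and B blocks are 32 * 32 = 1024 elements each. A small sketch of the same arithmetic in Rust (FP32, i.e. 4 bytes per element, matching the script's elementSize of 4), not part of either file in the diff:

fn block_bytes_sketch() -> u64 {
    // Same values the Swift script sets above.
    let (m, n): (u64, u64) = (2, 4);
    let (m_simd, n_simd, k_simd): (u64, u64, u64) = (16, 16, 32);
    let (m_splits, n_splits) = (2u64, 2u64);
    let m_group = m_simd * m_splits; // 32
    let n_group = n_simd * n_splits; // 32
    let a_block = m_group * k_simd; // 1024 elements
    let b_block = k_simd * n_group; // 1024 elements
    let mut block_elements = a_block + b_block; // 2048
    if m % 8 != 0 && n % 8 != 0 {
        // Edge-tile path: also reserve room for a C block, which is smaller here anyway.
        block_elements = block_elements.max(m_group * n_group); // max(2048, 1024)
    }
    block_elements * 4 // 8192 bytes, the value setThreadgroupMemoryLength receives
}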

View File

@@ -725,3 +725,76 @@ fn where_cond() {
     );
     assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
 }
fn run_gemm<T: Clone>(
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
rhs: &[T],
rhs_stride: Vec<usize>,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(rhs) as u64,
options,
);
let length = b * m * n;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_gemm(
&device,
command_buffer,
&kernels,
"sgemm",
(b, m, n, k),
&lhs_stride,
0,
&lhs,
&rhs_stride,
0,
&rhs,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(length)
}
#[test]
fn gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
assert_eq!(
approx(results, 4),
vec![
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
518.0, 548.0, 578.0
]
);
}
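The expected vectors in these asserts can be reproduced with a plain row-major reference matmul on the CPU; a minimal sketch (not part of the test file, helper name is made up) that yields the same 20.0, 23.0, ... values:

fn reference_gemm(b: usize, m: usize, n: usize, k: usize) -> Vec<f32> {
    // Same inputs as the test: lhs and rhs are filled with 0, 1, 2, ... across the whole batch.
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    let mut out = vec![0.0f32; b * m * n];
    for bi in 0..b {
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for l in 0..k {
                    // lhs is (b, m, k) row-major, rhs is (b, k, n) row-major.
                    acc += lhs[bi * m * k + i * k + l] * rhs[bi * n * k + l * n + j];
                }
                out[bi * m * n + i * n + j] = acc;
            }
        }
    }
    out
}

fn main() {
    // (b, m, n, k) = (2, 2, 4, 3) reproduces the second expected vector above,
    // starting 20.0, 23.0, 26.0, 29.0 and ending 548.0, 578.0.
    println!("{:?}", reference_gemm(2, 2, 4, 3));
}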