Compare commits

...

2 Commits

Author SHA1 Message Date
ce0783d9ff Stash for debugging 2023-12-10 13:11:53 +01:00
35352e441a Begin adding mfa support 2023-12-08 21:51:49 +01:00
4 changed files with 452 additions and 74 deletions

View File

@ -795,15 +795,16 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
let (type_id, size) = match self.dtype {
let (type_id, size, name) = match self.dtype {
DType::F32 => (
metal::mps::MPS_FLOATBIT_ENCODING | 32,
core::mem::size_of::<f32>() as NSUInteger,
"sgemm",
),
DType::F16 => (
metal::mps::MPS_FLOATBIT_ENCODING | 16,
core::mem::size_of::<f16>() as NSUInteger,
"hgemm",
),
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
};
@ -837,60 +838,37 @@ impl BackendStorage for MetalStorage {
mnk: (m, n, k),
})?
};
let b = b as NSUInteger;
let m = m as NSUInteger;
let n = n as NSUInteger;
let k = k as NSUInteger;
let left_matrix = self.matrix(
(b, m, k),
transpose_left,
size,
lhs_l.start_offset() as NSUInteger * size,
type_id,
)?;
let right_matrix = rhs.matrix(
(b, k, n),
transpose_right,
size,
rhs_l.start_offset() as NSUInteger * size,
type_id,
)?;
let (result_matrix, out_buffer) =
self.device
.new_matrix((b, m, n), size, type_id, self.dtype)?;
let result_buffer = self.device.new_buffer(b * m * n, self.dtype);
let command_buffer = self.device.command_buffer();
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
command_buffer.set_label("mfa gemm");
candle_metal_kernels::call_mfa_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
&self.buffer,
lhs_l.shape().dims(),
&rhs.buffer,
rhs_l.shape().dims(),
&result_buffer,
(b, m, n, k),
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
.map_err(MetalError::from)?;
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
command_buffer.set_label("matmul");
drop(command_buffer);
self.device.commit();
Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
Ok(Self::new(
self.buffer.clone(),
self.device.clone(),
self.dtype(),
))
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {

View File

@ -11,6 +11,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
metal-flash-attention = { path = "../../../metal-flash-attention" }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -1,9 +1,12 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, Library, MTLSize,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceUsage, MTLSize,
NSUInteger,
};
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use std::ffi::c_void;
use std::hash::Hash;
use std::io::{stdout, Write};
use std::sync::RwLock;
const AFFINE: &str = include_str!("affine.metal");
@ -13,6 +16,7 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const MFA_LIB: &[u8] = include_bytes!("mfa.metallib");
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
@ -105,6 +109,7 @@ pub enum Source {
Ternary,
Cast,
Reduce,
MetalFlashAttention,
}
macro_rules! ops{
@ -179,7 +184,7 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
}
}
type KernelMap<T> = HashMap<&'static str, T>;
type KernelMap<T> = HashMap<KernelKey, T>;
type Libraries = HashMap<Source, Library>;
type Pipelines = KernelMap<ComputePipelineState>;
@ -189,6 +194,22 @@ pub struct Kernels {
pipelines: RwLock<Pipelines>,
}
enum LibraryDefinition {
Source(&'static str),
Data(&'static [u8]),
}
impl From<&'static str> for LibraryDefinition {
fn from(s: &'static str) -> Self {
Self::Source(s)
}
}
impl From<&'static [u8]> for LibraryDefinition {
fn from(s: &'static [u8]) -> Self {
Self::Data(s)
}
}
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
@ -199,15 +220,16 @@ impl Kernels {
}
}
fn get_library_source(&self, source: Source) -> &'static str {
fn get_library_source(&self, source: Source) -> LibraryDefinition {
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
Source::Binary => BINARY,
Source::Ternary => TERNARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Affine => AFFINE.into(),
Source::Unary => UNARY.into(),
Source::Binary => BINARY.into(),
Source::Ternary => TERNARY.into(),
Source::Indexing => INDEXING.into(),
Source::Cast => CAST.into(),
Source::Reduce => REDUCE.into(),
Source::MetalFlashAttention => MFA_LIB.into(),
}
}
@ -220,10 +242,15 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let source_content = self.get_library_source(source);
let lib = device
let lib = match self.get_library_source(source) {
LibraryDefinition::Source(source_content) => device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
LibraryDefinition::Data(data) => device
.new_library_with_data(data)
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
};
libraries.insert(source, lib.clone());
Ok(lib)
}
@ -233,43 +260,190 @@ impl Kernels {
&self,
device: &Device,
source: Source,
name: &'static str,
key: KernelKey,
) -> Result<Function, MetalKernelError> {
let func = self
.load_library(device, source)?
.get_function(name, None)
.get_function(
key.name,
key.constants.map(|c| c.create_function_constant_values()),
)
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
Ok(func)
// let mut funcs = self.funcs.write()?;
// if let Some(func) = funcs.get(name) {
// Ok(func.clone())
// } else {
// funcs.insert(name, func.clone());
// Ok(func)
// }
}
pub fn load_pipeline(
pub fn load_pipeline<T: Into<KernelKey>>(
&self,
device: &Device,
source: Source,
name: &'static str,
key: T,
) -> Result<ComputePipelineState, MetalKernelError> {
let key: KernelKey = key.into();
let mut pipelines = self.pipelines.write()?;
if let Some(pipeline) = pipelines.get(name) {
if let Some(pipeline) = pipelines.get(&key) {
Ok(pipeline.clone())
} else {
let func = self.load_function(device, source, name)?;
let func = self.load_function(device, source, key.clone())?;
let pipeline = device
.new_compute_pipeline_state_with_function(&func)
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
pipelines.insert(name, pipeline.clone());
pipelines.insert(key, pipeline.clone());
Ok(pipeline)
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct KernelKey {
name: &'static str,
constants: Option<ConstantValues>,
}
impl KernelKey {
fn new(name: &'static str) -> Self {
Self {
name,
constants: None,
}
}
fn with_constants(mut self, constants: ConstantValues) -> Self {
self.constants = Some(constants);
self
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ConstantValueId {
Index(NSUInteger),
Name(&'static str),
}
trait MetalDType {
const MTL_DATA_TYPE: MTLDataType;
}
macro_rules! metal_dtype {
($ty:ty, $mtl_data_type:ident) => {
impl MetalDType for $ty {
const MTL_DATA_TYPE: MTLDataType = MTLDataType::$mtl_data_type;
}
};
}
metal_dtype!(f32, Float);
metal_dtype!(u32, UInt);
metal_dtype!(u16, UShort);
metal_dtype!(bool, Bool);
#[derive(Debug, Clone, PartialEq)]
enum ConstantValueType {
Float(f32),
Uint(u32),
UShort(u16),
Bool(bool),
}
impl Hash for ConstantValueType {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
use ConstantValueType::*;
match self {
Float(v) => v.to_bits().hash(state),
Uint(v) => v.hash(state),
UShort(v) => v.hash(state),
Bool(v) => v.hash(state),
}
}
}
impl Eq for ConstantValueType {}
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConstantValues(BTreeMap<ConstantValueId, ConstantValueType>);
macro_rules! add_indexed_constant {
($fcv:expr, $value:expr, $ty:ty, $idx:expr) => {
$fcv.set_constant_value_at_index(
$value as *const $ty as *const c_void,
<$ty>::MTL_DATA_TYPE,
$idx,
)
};
}
macro_rules! add_named_constant {
($fcv:expr, $value:expr, $ty:ty, $name:expr) => {
$fcv.set_constant_value_with_name(
$value as *const $ty as *const c_void,
<$ty>::MTL_DATA_TYPE,
$name,
)
};
}
impl Hash for ConstantValues {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
for (id, value) in &self.0 {
id.hash(state);
value.hash(state);
}
}
}
impl ConstantValues {
fn new() -> Self {
Self(BTreeMap::new())
}
fn set(mut self, id: impl Into<ConstantValueId>, value: impl Into<ConstantValueType>) -> Self {
self.0.insert(id.into(), value.into());
self
}
fn create_function_constant_values(&self) -> FunctionConstantValues {
use ConstantValueId::*;
use ConstantValueType::*;
let mut function_values = FunctionConstantValues::new();
for (id, value) in &self.0 {
match (&id, &value) {
(Index(index), Float(value)) => {
add_indexed_constant!(function_values, value, f32, *index);
}
(Index(index), Uint(value)) => {
add_indexed_constant!(function_values, value, u32, *index);
}
(Index(index), UShort(value)) => {
add_indexed_constant!(function_values, value, u16, *index);
}
(Index(index), Bool(value)) => {
add_indexed_constant!(function_values, value, bool, *index);
}
(Name(name), Float(value)) => {
add_named_constant!(function_values, value, f32, name);
}
(Name(name), Uint(value)) => {
add_named_constant!(function_values, value, u32, name);
}
(Name(name), UShort(value)) => {
add_named_constant!(function_values, value, u16, name);
}
(Name(name), Bool(value)) => {
add_named_constant!(function_values, value, bool, name);
}
}
}
function_values
}
}
impl From<&'static str> for KernelKey {
fn from(name: &'static str) -> Self {
Self {
name,
constants: None,
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
device: &Device,
@ -706,5 +880,230 @@ pub fn call_index_select(
Ok(())
}
impl From<NSUInteger> for ConstantValueId {
fn from(idx: NSUInteger) -> Self {
Self::Index(idx)
}
}
impl From<usize> for ConstantValueId {
fn from(idx: usize) -> Self {
ConstantValueId::from(idx as NSUInteger)
}
}
impl From<i32> for ConstantValueId {
fn from(idx: i32) -> Self {
ConstantValueId::from(idx as NSUInteger)
}
}
impl From<&'static str> for ConstantValueId {
fn from(name: &'static str) -> Self {
Self::Name(name)
}
}
macro_rules! to_constant_value {
($ty:ty, $constant_value_type:ident) => {
to_constant_value!($ty, $ty, $constant_value_type);
};
($ty:ty, $via:ty, $constant_value_type:ident) => {
impl From<$ty> for ConstantValueType {
fn from(v: $ty) -> Self {
Self::$constant_value_type(v as $via)
}
}
};
}
to_constant_value!(f32, Float);
to_constant_value!(u32, Uint);
to_constant_value!(usize, u32, Uint);
to_constant_value!(u16, UShort);
to_constant_value!(bool, Bool);
struct MFAGemmConfig {
m: usize,
k: usize,
n: usize,
transpose_left: bool,
transpose_right: bool,
batched: bool,
m_simd: u16,
n_simd: u16,
k_simd: u16,
m_splits: u16,
n_splits: u16,
m_group: u16,
n_group: u16,
}
impl From<MFAGemmConfig> for ConstantValues {
fn from(conf: MFAGemmConfig) -> Self {
ConstantValues::new()
.set(0, conf.m)
.set(1, conf.k)
.set(2, conf.n)
.set(10, conf.transpose_left)
.set(11, conf.transpose_right)
.set(12, false)
.set(20, 1.0)
.set(21, 0.0)
.set(100, conf.batched)
.set(101, false)
.set(50001, false)
.set(200, conf.m_simd)
.set(201, conf.n_simd)
.set(202, conf.k_simd)
.set(210, conf.m_splits)
.set(211, conf.n_splits)
// garbage
.set(102, false)
.set(103, false)
.set(113, false)
.set(50000, false)
}
}
#[allow(clippy::too_many_arguments)]
pub fn call_mfa_gemm(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
lhs: &Buffer,
lhs_dims: &[usize],
rhs: &Buffer,
rhs_dims: &[usize],
output: &Buffer,
(b, m, n, k): (usize, usize, usize, usize),
transpose_left: bool,
transpose_right: bool,
) -> Result<(), MetalKernelError> {
let batched = b > 1;
let mut c_elements = m * n;
if batched {
c_elements *= 2;
}
let is_half = name == "hgemm";
let is_float = name == "sgemm";
let mut m_group = 32;
let mut n_group = 32;
let mut k_simd = 32;
if c_elements > 10 ^ 6 {
m_group = 48;
n_group = 48;
}
// If K_simd is perfectly equal to matrix K, the compiler can elide a large
// amount of logic in the kernel.
if k >= 33 && k <= 40 {
k_simd = 40;
} else if is_half && k >= 73 && k >= 80 {
k_simd = 80;
} else if c_elements > 10 ^ 6 {
if k <= 16 {
k_simd = 16;
} else if k <= 24 {
k_simd = 24;
} else if k <= 32 {
k_simd = 32;
} else if k <= 48 {
k_simd = 24;
} else if k <= 64 {
k_simd = 32;
} else if is_float {
k_simd = 24;
}
}
let m_splits = 2;
let n_splits = 2;
let m_simd = m_group / m_splits;
let n_simd = n_group / n_splits;
let config = MFAGemmConfig {
m,
k,
n,
transpose_left,
transpose_right,
batched,
m_simd,
n_simd,
k_simd,
m_splits,
n_splits,
m_group,
n_group,
};
let pipeline = kernels.load_pipeline(
device,
Source::MetalFlashAttention,
KernelKey::new(name).with_constants(config.into()),
)?;
let block_type_size = if is_half { 2 } else { 4 };
let a_block_bytes = m_group * k_simd * block_type_size;
let b_block_bytes = k_simd * n_group * block_type_size;
let c_block_bytes = m_group * n_group * block_type_size;
let mut thread_group_memory_length = a_block_bytes + b_block_bytes;
if m % 8 > 0 && n % 8 > 0 {
thread_group_memory_length = core::cmp::max(thread_group_memory_length, c_block_bytes);
}
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
encoder.use_resources(&[&lhs, &rhs], MTLResourceUsage::Read);
encoder.use_resource(&output, MTLResourceUsage::Write);
encoder.set_buffers(0, &[Some(lhs), Some(rhs), Some(output)], &[0; 3]);
let ceil_divide = |a, b| (a + b - 1) / b;
let mut grid_z = 1;
if batched {
grid_z = lhs_dims[..lhs_dims.len() - 2].iter().product();
let byte_stride = |shape: &[usize]| -> u64 {
let rank = shape.len();
let mut output = core::mem::size_of::<f32>() * shape[rank - 2] * shape[rank - 1];
if shape[..shape.len() - 2].iter().product::<usize>() == 1 {
output = 0;
}
output as u64
};
let byte_stride_a = byte_stride(lhs_dims);
let byte_stride_b = byte_stride(rhs_dims);
let byte_stride_c = byte_stride(&[m, n]);
type BatchConfig = (u64, u64, u64, u64);
let mut batching_conf: Vec<BatchConfig> = vec![];
for i in 0..grid_z {
batching_conf.push((
i as u64 * byte_stride_a,
i as u64 * byte_stride_b,
i as u64 * byte_stride_c,
0u64,
));
}
set_param(encoder, 10, batching_conf.as_slice());
}
let grid_size = MTLSize::new(
ceil_divide(n as NSUInteger, n_group as NSUInteger),
ceil_divide(m as NSUInteger, m_group as NSUInteger),
grid_z as NSUInteger,
);
let group_size = MTLSize::new((32 * m_splits * n_splits) as NSUInteger, 1, 1);
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests;

Binary file not shown.