Begin adding mfa support

This commit is contained in:
Ivar Flakstad
2023-12-08 21:51:49 +01:00
parent 2ca086939f
commit 35352e441a
4 changed files with 205 additions and 32 deletions

View File

@ -795,7 +795,6 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
let (type_id, size) = match self.dtype {
DType::F32 => (
metal::mps::MPS_FLOATBIT_ENCODING | 32,

View File

@ -11,6 +11,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
metal-flash-attention = { path = "../../../metal-flash-attention" }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -1,9 +1,7 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, Library, MTLSize,
};
use metal::{Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::hash::Hash;
use std::sync::RwLock;
const AFFINE: &str = include_str!("affine.metal");
@ -13,6 +11,7 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const MFA_LIB: &[u8] = include_bytes!("mfa.metallib");
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
@ -105,6 +104,7 @@ pub enum Source {
Ternary,
Cast,
Reduce,
MetalFlashAttention,
}
macro_rules! ops{
@ -179,7 +179,7 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
}
}
type KernelMap<T> = HashMap<&'static str, T>;
type KernelMap<T> = HashMap<KernelKey, T>;
type Libraries = HashMap<Source, Library>;
type Pipelines = KernelMap<ComputePipelineState>;
@ -189,6 +189,22 @@ pub struct Kernels {
pipelines: RwLock<Pipelines>,
}
/// How a Metal library is made available to the kernel loader.
///
/// In this file `Source` payloads are handed to `new_library_with_source`
/// and `Data` payloads to `new_library_with_data`.
enum LibraryDefinition {
    /// Library text that is compiled when the library is first loaded.
    Source(&'static str),
    /// Raw library bytes embedded in the binary.
    Data(&'static [u8]),
}

/// Lets an embedded source string be used wherever a definition is expected.
impl From<&'static str> for LibraryDefinition {
    fn from(text: &'static str) -> Self {
        LibraryDefinition::Source(text)
    }
}

/// Lets an embedded byte blob be used wherever a definition is expected.
impl From<&'static [u8]> for LibraryDefinition {
    fn from(bytes: &'static [u8]) -> Self {
        LibraryDefinition::Data(bytes)
    }
}
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
@ -199,15 +215,16 @@ impl Kernels {
}
}
fn get_library_source(&self, source: Source) -> &'static str {
fn get_library_source(&self, source: Source) -> LibraryDefinition {
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
Source::Binary => BINARY,
Source::Ternary => TERNARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Affine => AFFINE.into(),
Source::Unary => UNARY.into(),
Source::Binary => BINARY.into(),
Source::Ternary => TERNARY.into(),
Source::Indexing => INDEXING.into(),
Source::Cast => CAST.into(),
Source::Reduce => REDUCE.into(),
Source::MetalFlashAttention => MFA_LIB.into(),
}
}
@ -220,10 +237,15 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let source_content = self.get_library_source(source);
let lib = device
let lib = match self.get_library_source(source) {
LibraryDefinition::Source(source_content) => device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
LibraryDefinition::Data(data) => device
.new_library_with_data(data)
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
};
libraries.insert(source, lib.clone());
Ok(lib)
}
@ -233,43 +255,154 @@ impl Kernels {
&self,
device: &Device,
source: Source,
name: &'static str,
key: KernelKey,
) -> Result<Function, MetalKernelError> {
let func = self
.load_library(device, source)?
.get_function(name, None)
.get_function(key.name, key.constants.map(|c| c.create_function_constant_values()))
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
Ok(func)
// let mut funcs = self.funcs.write()?;
// if let Some(func) = funcs.get(name) {
// Ok(func.clone())
// } else {
// funcs.insert(name, func.clone());
// Ok(func)
// }
}
pub fn load_pipeline(
pub fn load_pipeline<T: Into<KernelKey>>(
&self,
device: &Device,
source: Source,
name: &'static str,
key: T,
) -> Result<ComputePipelineState, MetalKernelError> {
let key: KernelKey = key.into();
let mut pipelines = self.pipelines.write()?;
if let Some(pipeline) = pipelines.get(name) {
if let Some(pipeline) = pipelines.get(&key) {
Ok(pipeline.clone())
} else {
let func = self.load_function(device, source, name)?;
let func = self.load_function(device, source, key.clone())?;
let pipeline = device
.new_compute_pipeline_state_with_function(&func)
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
pipelines.insert(name, pipeline.clone());
pipelines.insert(key, pipeline.clone());
Ok(pipeline)
}
}
}
/// Cache key for a compiled compute pipeline.
///
/// `load_pipeline` looks pipelines up by this key, so two requests share a
/// pipeline only when both the kernel name and the constants compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct KernelKey {
/// Name of the function looked up in the library via `get_function`.
name: &'static str,
/// Optional function constants; `None` means the function is fetched
/// without any constant values.
constants: Option<ConstantValues>,
}
/// Identifies which function constant a value is bound to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum ConstantValueId {
/// Bound by position via `set_constant_value_at_index`.
Index(NSUInteger),
/// Bound by name via `set_constant_value_with_name`.
Name(&'static str),
}
/// Associates a Rust scalar type with the `MTLDataType` tag that is passed
/// alongside a raw pointer when a function constant is set.
trait MetalDType {
/// The Metal data-type tag reported for values of the implementing type.
const MTL_DATA_TYPE: MTLDataType;
}
/// Implements `MetalDType` for a Rust scalar type, tagging it with the given
/// `MTLDataType` variant.
macro_rules! metal_dtype {
    ($rust_ty:ty, $variant:ident) => {
        impl MetalDType for $rust_ty {
            const MTL_DATA_TYPE: MTLDataType = MTLDataType::$variant;
        }
    };
}

// One impl per scalar type that `ConstantValue` can carry.
metal_dtype!(f32, Float);
metal_dtype!(u32, UInt);
metal_dtype!(u16, UShort);
metal_dtype!(bool, Bool);
/// A single function-constant value that can be baked into a kernel.
///
/// Values of this type are part of `KernelKey`, which is used as a `HashMap`
/// key, so equality and hashing must satisfy the `Eq`/`Hash` contracts:
///
/// * Floats are compared and hashed by their bit pattern (`f32::to_bits`).
///   This makes equality reflexive even for NaN, so a NaN constant can be
///   found again in the cache instead of being re-inserted on every lookup,
///   and it keeps `0.0` and `-0.0` apart since they are different constants.
/// * The variant discriminant is mixed into the hash so that, for example,
///   `Uint(1)` and `Float(..)` do not systematically collide.
#[derive(Debug, Clone)]
enum ConstantValue {
    Float(f32),
    Uint(u32),
    UShort(u16),
    Bool(bool),
}

impl PartialEq for ConstantValue {
    fn eq(&self, other: &Self) -> bool {
        use ConstantValue::*;
        match (self, other) {
            // Bitwise comparison: total and reflexive, unlike `f32 == f32`.
            (Float(a), Float(b)) => a.to_bits() == b.to_bits(),
            (Uint(a), Uint(b)) => a == b,
            (UShort(a), UShort(b)) => a == b,
            (Bool(a), Bool(b)) => a == b,
            // Different variants are never equal.
            _ => false,
        }
    }
}

// Sound because `eq` above is a total equivalence relation.
impl Eq for ConstantValue {}

impl Hash for ConstantValue {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use ConstantValue::*;
        // Distinguish variants before hashing the payload.
        std::mem::discriminant(self).hash(state);
        match self {
            // Must agree with `eq`: equal bit patterns hash identically.
            Float(v) => v.to_bits().hash(state),
            Uint(v) => v.hash(state),
            UShort(v) => v.hash(state),
            Bool(v) => v.hash(state),
        }
    }
}
/// Ordered list of `(id, value)` function-constant bindings.
///
/// Equality and hashing are derived from the inner `Vec`, so the same
/// bindings listed in a different order are treated as a different set.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ConstantValues(Vec<(ConstantValueId, ConstantValue)>);
/// Sets one function constant by index.
///
/// `$src` must be a reference to a `$scalar`; it is passed on as an untyped
/// pointer together with the scalar's `MetalDType::MTL_DATA_TYPE` tag.
macro_rules! add_indexed_constant {
    ($target:expr, $src:expr, $scalar:ty, $slot:expr) => {
        $target.set_constant_value_at_index(
            ($src as *const $scalar).cast::<c_void>(),
            <$scalar>::MTL_DATA_TYPE,
            $slot,
        )
    };
}
/// Sets one function constant by name.
///
/// `$src` must be a reference to a `$scalar`; it is passed on as an untyped
/// pointer together with the scalar's `MetalDType::MTL_DATA_TYPE` tag.
macro_rules! add_named_constant {
    ($target:expr, $src:expr, $scalar:ty, $label:expr) => {
        $target.set_constant_value_with_name(
            ($src as *const $scalar).cast::<c_void>(),
            <$scalar>::MTL_DATA_TYPE,
            $label,
        )
    };
}
impl ConstantValues {
    /// Builds a `FunctionConstantValues` object holding every binding in
    /// this list, applied in list order.
    ///
    /// Each binding is dispatched first on how it is identified (index or
    /// name) and then on the scalar type it carries, so that the matching
    /// data-type tag is passed along with the value pointer.
    fn create_function_constant_values(&self) -> FunctionConstantValues {
        use ConstantValue::*;
        use ConstantValueId::*;

        let mut fcv = FunctionConstantValues::new();
        for (id, value) in self.0.iter() {
            match id {
                Index(slot) => match value {
                    Float(v) => add_indexed_constant!(fcv, v, f32, *slot),
                    Uint(v) => add_indexed_constant!(fcv, v, u32, *slot),
                    UShort(v) => add_indexed_constant!(fcv, v, u16, *slot),
                    Bool(v) => add_indexed_constant!(fcv, v, bool, *slot),
                },
                Name(label) => match value {
                    Float(v) => add_named_constant!(fcv, v, f32, label),
                    Uint(v) => add_named_constant!(fcv, v, u32, label),
                    UShort(v) => add_named_constant!(fcv, v, u16, label),
                    Bool(v) => add_named_constant!(fcv, v, bool, label),
                },
            }
        }
        fcv
    }
}
impl From<&'static str> for KernelKey {
fn from(name: &'static str) -> Self {
Self {
name,
constants: None,
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
device: &Device,
@ -706,5 +839,45 @@ pub fn call_index_select(
Ok(())
}
/// Encodes a dispatch of the kernel `name` from the metal-flash-attention
/// library onto `command_buffer`.
///
/// The kernel receives, in order: the total element count of `shape`, the
/// number of dimensions, `shape`, `strides`, the input buffer with its
/// offset, and the output buffer with its offset. The launch geometry comes
/// from `linear_split` over the total element count.
///
/// NOTE(review): this argument layout is the one visible in this function
/// only; whether it matches what the MFA GEMM kernels expect is not
/// established here and should be confirmed.
///
/// # Errors
///
/// Returns a `MetalKernelError` if the pipeline for `name` cannot be loaded.
// Kept as a flat argument list for parity with the other `call_*` helpers.
#[allow(clippy::too_many_arguments)]
pub fn call_mfa_gemm(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    strides: &[usize],
    offset: usize,
    output: &Buffer,
    output_offset: usize,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::MetalFlashAttention, name)?;
    let rank = shape.len();

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    // The element count serves both as a kernel argument and as the width
    // of the linear dispatch.
    let element_count = shape.iter().product::<usize>();
    set_params!(
        encoder,
        (
            element_count,
            rank,
            shape,
            strides,
            (input, offset),
            (output, output_offset)
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, element_count);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}
#[cfg(test)]
mod tests;

Binary file not shown.