mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Tmp.
This commit is contained in:
@ -1,6 +1,6 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
Device, Function, Library, MTLSize,
|
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
|
|||||||
const TERNARY: &str = include_str!("ternary.metal");
|
const TERNARY: &str = include_str!("ternary.metal");
|
||||||
const CAST: &str = include_str!("cast.metal");
|
const CAST: &str = include_str!("cast.metal");
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &str = include_str!("reduce.metal");
|
||||||
|
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||||
|
|
||||||
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
||||||
let size = length as u64;
|
let size = length as u64;
|
||||||
@ -105,6 +106,7 @@ pub enum Source {
|
|||||||
Ternary,
|
Ternary,
|
||||||
Cast,
|
Cast,
|
||||||
Reduce,
|
Reduce,
|
||||||
|
Mfa,
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! ops{
|
macro_rules! ops{
|
||||||
@ -179,9 +181,88 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type KernelMap<T> = HashMap<&'static str, T>;
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub enum Value {
|
||||||
|
U32(u32),
|
||||||
|
Bool(bool),
|
||||||
|
F32(f32),
|
||||||
|
U16(u16),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::hash::Hash for Value {
|
||||||
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||||
|
match self {
|
||||||
|
Value::F32(v) => v.to_bits().hash(state),
|
||||||
|
Value::U32(v) => v.hash(state),
|
||||||
|
Value::U16(v) => v.hash(state),
|
||||||
|
Value::Bool(v) => v.hash(state),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Value {
|
||||||
|
fn data_type(&self) -> MTLDataType {
|
||||||
|
match self {
|
||||||
|
Value::U32(_) => MTLDataType::UInt,
|
||||||
|
Value::F32(_) => MTLDataType::Float,
|
||||||
|
Value::U16(_) => MTLDataType::UShort,
|
||||||
|
Value::Bool(_) => MTLDataType::Bool,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Not true, good enough for our purposes.
|
||||||
|
impl Eq for Value {}
|
||||||
|
|
||||||
|
#[derive(Debug, Eq, PartialEq, Hash)]
|
||||||
|
struct ConstantValues(Vec<(usize, Value)>);
|
||||||
|
|
||||||
|
impl ConstantValues {
|
||||||
|
pub fn new(values: Vec<(usize, Value)>) -> Self {
|
||||||
|
Self(values)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn function_constant_values(&self) -> FunctionConstantValues {
|
||||||
|
let f = FunctionConstantValues::new();
|
||||||
|
for (index, value) in &self.0 {
|
||||||
|
let ty = value.data_type();
|
||||||
|
match value {
|
||||||
|
Value::U32(v) => {
|
||||||
|
f.set_constant_value_at_index(
|
||||||
|
v as *const u32 as *const c_void,
|
||||||
|
ty,
|
||||||
|
*index as u64,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Value::F32(v) => {
|
||||||
|
f.set_constant_value_at_index(
|
||||||
|
v as *const f32 as *const c_void,
|
||||||
|
ty,
|
||||||
|
*index as u64,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Value::U16(v) => {
|
||||||
|
f.set_constant_value_at_index(
|
||||||
|
v as *const u16 as *const c_void,
|
||||||
|
ty,
|
||||||
|
*index as u64,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Value::Bool(v) => {
|
||||||
|
f.set_constant_value_at_index(
|
||||||
|
v as *const bool as *const c_void,
|
||||||
|
ty,
|
||||||
|
*index as u64,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Libraries = HashMap<Source, Library>;
|
type Libraries = HashMap<Source, Library>;
|
||||||
type Pipelines = KernelMap<ComputePipelineState>;
|
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub struct Kernels {
|
pub struct Kernels {
|
||||||
@ -208,6 +289,7 @@ impl Kernels {
|
|||||||
Source::Indexing => INDEXING,
|
Source::Indexing => INDEXING,
|
||||||
Source::Cast => CAST,
|
Source::Cast => CAST,
|
||||||
Source::Reduce => REDUCE,
|
Source::Reduce => REDUCE,
|
||||||
|
Source::Mfa => unimplemented!("Mfa is not a source"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -220,10 +302,20 @@ impl Kernels {
|
|||||||
if let Some(lib) = libraries.get(&source) {
|
if let Some(lib) = libraries.get(&source) {
|
||||||
Ok(lib.clone())
|
Ok(lib.clone())
|
||||||
} else {
|
} else {
|
||||||
let source_content = self.get_library_source(source);
|
let lib = match source {
|
||||||
let lib = device
|
Source::Mfa => {
|
||||||
.new_library_with_source(source_content, &CompileOptions::new())
|
let source_data = MFA;
|
||||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
|
device
|
||||||
|
.new_library_with_data(source_data)
|
||||||
|
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||||
|
}
|
||||||
|
source => {
|
||||||
|
let source_content = self.get_library_source(source);
|
||||||
|
device
|
||||||
|
.new_library_with_source(source_content, &CompileOptions::new())
|
||||||
|
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||||
|
}
|
||||||
|
};
|
||||||
libraries.insert(source, lib.clone());
|
libraries.insert(source, lib.clone());
|
||||||
Ok(lib)
|
Ok(lib)
|
||||||
}
|
}
|
||||||
@ -234,19 +326,41 @@ impl Kernels {
|
|||||||
device: &Device,
|
device: &Device,
|
||||||
source: Source,
|
source: Source,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
|
constants: Option<FunctionConstantValues>,
|
||||||
) -> Result<Function, MetalKernelError> {
|
) -> Result<Function, MetalKernelError> {
|
||||||
let func = self
|
let func = self
|
||||||
.load_library(device, source)?
|
.load_library(device, source)?
|
||||||
.get_function(name, None)
|
.get_function(name, constants)
|
||||||
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
||||||
Ok(func)
|
Ok(func)
|
||||||
// let mut funcs = self.funcs.write()?;
|
}
|
||||||
// if let Some(func) = funcs.get(name) {
|
|
||||||
// Ok(func.clone())
|
fn load_pipeline_with_constants(
|
||||||
// } else {
|
&self,
|
||||||
// funcs.insert(name, func.clone());
|
device: &Device,
|
||||||
// Ok(func)
|
source: Source,
|
||||||
// }
|
name: &'static str,
|
||||||
|
constants: Option<ConstantValues>,
|
||||||
|
) -> Result<ComputePipelineState, MetalKernelError> {
|
||||||
|
let mut pipelines = self.pipelines.write()?;
|
||||||
|
let key = (name, constants);
|
||||||
|
if let Some(pipeline) = pipelines.get(&key) {
|
||||||
|
Ok(pipeline.clone())
|
||||||
|
} else {
|
||||||
|
let (name, constants) = key;
|
||||||
|
let func = self.load_function(
|
||||||
|
device,
|
||||||
|
source,
|
||||||
|
name,
|
||||||
|
constants.as_ref().map(|c| c.function_constant_values()),
|
||||||
|
)?;
|
||||||
|
let pipeline = device
|
||||||
|
.new_compute_pipeline_state_with_function(&func)
|
||||||
|
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
|
||||||
|
pipelines.insert((name, constants), pipeline.clone());
|
||||||
|
|
||||||
|
Ok(pipeline)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_pipeline(
|
pub fn load_pipeline(
|
||||||
@ -255,18 +369,7 @@ impl Kernels {
|
|||||||
source: Source,
|
source: Source,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
) -> Result<ComputePipelineState, MetalKernelError> {
|
) -> Result<ComputePipelineState, MetalKernelError> {
|
||||||
let mut pipelines = self.pipelines.write()?;
|
self.load_pipeline_with_constants(device, source, name, None)
|
||||||
if let Some(pipeline) = pipelines.get(name) {
|
|
||||||
Ok(pipeline.clone())
|
|
||||||
} else {
|
|
||||||
let func = self.load_function(device, source, name)?;
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(&func)
|
|
||||||
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
|
|
||||||
pipelines.insert(name, pipeline.clone());
|
|
||||||
|
|
||||||
Ok(pipeline)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -706,5 +809,130 @@ pub fn call_index_select(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_gemm(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
kernels: &Kernels,
|
||||||
|
name: &'static str,
|
||||||
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
lhs_offset: usize,
|
||||||
|
lhs_buffer: &Buffer,
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
rhs_offset: usize,
|
||||||
|
rhs_buffer: &Buffer,
|
||||||
|
output: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let a_trans = false;
|
||||||
|
let b_trans = false;
|
||||||
|
let d_trans = false;
|
||||||
|
let alpha = 1.0;
|
||||||
|
let beta = 0.0;
|
||||||
|
let batched = b > 1;
|
||||||
|
let fused_activation = false;
|
||||||
|
let fused_bias = false;
|
||||||
|
let m_simd = 16;
|
||||||
|
let n_simd = 16;
|
||||||
|
let k_simd = 16;
|
||||||
|
let m_splits = 2;
|
||||||
|
let n_splits = 2;
|
||||||
|
let constants = Some(ConstantValues::new(vec![
|
||||||
|
(0, Value::U32(m as u32)),
|
||||||
|
(1, Value::U32(n as u32)),
|
||||||
|
(2, Value::U32(k as u32)),
|
||||||
|
(10, Value::Bool(a_trans)),
|
||||||
|
(11, Value::Bool(b_trans)),
|
||||||
|
(13, Value::Bool(d_trans)),
|
||||||
|
(20, Value::F32(alpha)),
|
||||||
|
(21, Value::F32(beta)),
|
||||||
|
(100, Value::Bool(batched)),
|
||||||
|
(101, Value::Bool(fused_activation)),
|
||||||
|
(200, Value::U16(m_simd)),
|
||||||
|
(201, Value::U16(n_simd)),
|
||||||
|
(202, Value::U16(k_simd)),
|
||||||
|
(210, Value::U16(m_splits)),
|
||||||
|
(211, Value::U16(n_splits)),
|
||||||
|
(50_001, Value::Bool(fused_bias)),
|
||||||
|
]));
|
||||||
|
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
||||||
|
let m_group = m_simd * m_splits;
|
||||||
|
let n_group = n_simd * n_splits;
|
||||||
|
|
||||||
|
let a_block_length = m_group * k_simd;
|
||||||
|
let b_block_length = k_simd * n_group;
|
||||||
|
|
||||||
|
let mut block_elements = a_block_length + b_block_length;
|
||||||
|
if (m % 8 != 0) && (n % 8 != 0) {
|
||||||
|
let c_block_length = m_group * n_group;
|
||||||
|
block_elements = std::cmp::max(c_block_length, block_elements)
|
||||||
|
}
|
||||||
|
if fused_bias {
|
||||||
|
if d_trans {
|
||||||
|
block_elements = std::cmp::max(block_elements, m_group);
|
||||||
|
} else {
|
||||||
|
block_elements = std::cmp::max(block_elements, n_group);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO adapt for f16
|
||||||
|
let bytes = match name {
|
||||||
|
"sgemm" => 4,
|
||||||
|
"hgemm" => 2,
|
||||||
|
other => {
|
||||||
|
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||||
|
"{other} is not a valid kernel for gemm"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let block_bytes = block_elements * bytes;
|
||||||
|
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
encoder.set_threadgroup_memory_length(block_bytes.into(), 0);
|
||||||
|
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||||
|
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||||
|
encoder.set_buffer(2, Some(output), 0);
|
||||||
|
// TODO Tensor D
|
||||||
|
|
||||||
|
let grid_z = b;
|
||||||
|
let byte_stride_a: usize = *lhs_stride.get(lhs_stride.len() - 2).unwrap_or(&0);
|
||||||
|
let byte_stride_b = *rhs_stride.get(rhs_stride.len() - 2).unwrap_or(&0);
|
||||||
|
let byte_stride_c = m * n;
|
||||||
|
// TODO byte_stride_d
|
||||||
|
let byte_stride_d = 1;
|
||||||
|
|
||||||
|
let mut buffer = Vec::with_capacity(b * 4);
|
||||||
|
for i in 0..b {
|
||||||
|
buffer.push(i * byte_stride_a);
|
||||||
|
buffer.push(i * byte_stride_b);
|
||||||
|
buffer.push(i * byte_stride_c);
|
||||||
|
buffer.push(i * byte_stride_d);
|
||||||
|
}
|
||||||
|
encoder.set_bytes(
|
||||||
|
10,
|
||||||
|
buffer.len() as NSUInteger,
|
||||||
|
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
||||||
|
);
|
||||||
|
|
||||||
|
let grid_size = MTLSize {
|
||||||
|
width: divide(n, n_group.into()),
|
||||||
|
height: divide(m, m_group.into()),
|
||||||
|
depth: grid_z as NSUInteger,
|
||||||
|
};
|
||||||
|
let group_size = MTLSize {
|
||||||
|
width: 32 * (m_splits as u64) * (n_splits as u64),
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||||
|
encoder.end_encoding();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||||
|
((m + b - 1) / b) as NSUInteger
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
BIN
candle-metal-kernels/src/libMetalFlashAttention.metallib
Normal file
BIN
candle-metal-kernels/src/libMetalFlashAttention.metallib
Normal file
Binary file not shown.
@ -725,3 +725,66 @@ fn where_cond() {
|
|||||||
);
|
);
|
||||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn run_gemm<T: Clone>(
|
||||||
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
lhs: &[T],
|
||||||
|
lhs_stride: Vec<usize>,
|
||||||
|
rhs: &[T],
|
||||||
|
rhs_stride: Vec<usize>,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let device = device();
|
||||||
|
let kernels = Kernels::new();
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
|
|
||||||
|
let lhs = device.new_buffer_with_data(
|
||||||
|
lhs.as_ptr() as *const core::ffi::c_void,
|
||||||
|
std::mem::size_of_val(lhs) as u64,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
let rhs = device.new_buffer_with_data(
|
||||||
|
rhs.as_ptr() as *const core::ffi::c_void,
|
||||||
|
std::mem::size_of_val(rhs) as u64,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
let length = b * m * n;
|
||||||
|
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||||
|
call_gemm(
|
||||||
|
&device,
|
||||||
|
command_buffer,
|
||||||
|
&kernels,
|
||||||
|
"sgemm",
|
||||||
|
(b, m, n, k),
|
||||||
|
&lhs_stride,
|
||||||
|
0,
|
||||||
|
&lhs,
|
||||||
|
&rhs_stride,
|
||||||
|
0,
|
||||||
|
&rhs,
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
output.read_to_vec::<T>(length)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gemm() {
|
||||||
|
let (b, m, n, k) = (2, 2, 4, 3);
|
||||||
|
let lhs_stride = vec![m * k, k, 1];
|
||||||
|
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||||
|
let rhs_stride = vec![n * k, n, 1];
|
||||||
|
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||||
|
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
|
||||||
|
assert_eq!(
|
||||||
|
approx(results, 4),
|
||||||
|
vec![
|
||||||
|
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
|
||||||
|
518.0, 548.0, 578.0
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user