Compare commits

..

4 Commits

Author SHA1 Message Date
c65f68e988 Tmp gemm. 2023-11-19 20:43:59 +01:00
eed1631ee2 Reuse buffers on our own reference counts. 2023-11-18 23:28:59 +01:00
251c65f9f1 Metal operational. 2023-11-18 00:52:38 +01:00
a0010898cc Better batched matmul. 2023-11-17 10:36:57 +01:00
11 changed files with 911 additions and 187 deletions

View File

@ -61,10 +61,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
[profile.release-with-debug]
inherits = "release"

View File

@ -30,8 +30,6 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -43,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
metal = ["dep:metal", "dep:candle-metal-kernels"]

View File

@ -9,8 +9,6 @@ use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock};
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use dispatch::{Queue, QueueAttribute};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -39,9 +37,8 @@ pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
queue : Queue,
kernels: Arc<candle_metal_kernels::Kernels>,
buffers: Arc<RwLock<HashMap<usize, Vec<Arc<Buffer>>>>>,
}
impl std::fmt::Debug for MetalDevice {
@ -63,6 +60,10 @@ impl MetalDevice {
self.registry_id()
}
pub fn metal_device(&self) -> &metal::Device {
&self.device
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
@ -88,19 +89,21 @@ impl MetalDevice {
&self.device
}
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
let size = element_count * dtype.size_in_bytes();
let mut buffers = self.buffers.try_write().unwrap();
let subbuffers = buffers.entry(size).or_insert(vec![]);
for sub in &mut *subbuffers {
if sub.retain_count() == 1{
// if sub.retain_count() == 1{
// println!("{size} {:?}", );
if Arc::strong_count(sub) == 1 {
return sub.clone();
// println!("{size} {:?}", sub.retain_count());
}
}
let new_buffer = self.device
.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
new_buffer
}
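For context, the buffer reuse introduced here amounts to a size-keyed pool of Arc-counted buffers: an entry can be handed out again as soon as the pool holds the only remaining reference to it. Below is a minimal, GPU-free sketch of that idea; the FakeBuffer type and BufferPool struct are illustrative stand-ins rather than candle API, and the real device additionally has to coordinate reuse with in-flight command buffers.

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Stand-in for `metal::Buffer` so the sketch runs without a GPU.
#[allow(dead_code)]
struct FakeBuffer(usize);

struct BufferPool {
    // Buffers keyed by byte size; an entry is free again once the pool
    // holds the only remaining `Arc` to it.
    buffers: RwLock<HashMap<usize, Vec<Arc<FakeBuffer>>>>,
}

impl BufferPool {
    fn new() -> Self {
        Self { buffers: RwLock::new(HashMap::new()) }
    }

    fn new_buffer(&self, size: usize) -> Arc<FakeBuffer> {
        let mut buffers = self.buffers.write().unwrap();
        let subbuffers = buffers.entry(size).or_insert_with(Vec::new);
        for sub in subbuffers.iter() {
            // strong_count == 1 means every storage that used this buffer
            // has been dropped, so it can be handed out again.
            if Arc::strong_count(sub) == 1 {
                return sub.clone();
            }
        }
        let new_buffer = Arc::new(FakeBuffer(size));
        subbuffers.push(new_buffer.clone());
        new_buffer
    }
}

fn main() {
    let pool = BufferPool::new();
    let a = pool.new_buffer(1024);
    drop(a); // the pool's Arc is now the only one left
    let b = pool.new_buffer(1024); // the same allocation is reused
    assert_eq!(Arc::strong_count(&b), 2); // pool + `b`
}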
@ -122,7 +125,7 @@ impl MetalDevice {
#[derive(Debug, Clone)]
pub struct MetalStorage {
buffer: metal::Buffer,
buffer: Arc<metal::Buffer>,
device: MetalDevice,
dtype: DType,
}
@ -145,14 +148,13 @@ impl BackendStorage for MetalStorage {
fn to_cpu_storage(&self) -> Result<CpuStorage> {
let buffer = self.device.new_buffer_managed(self.buffer.length());
{
let command = self.device.command_buffer();
let blit = command.new_blit_command_encoder();
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
let command = self.device.command_buffer();
let blit = command.new_blit_command_encoder();
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
}
self.device.wait_until_completed();
match self.dtype {
@ -202,7 +204,7 @@ impl BackendStorage for MetalStorage {
name,
el,
&self.buffer,
&mut buffer,
&buffer,
mul as f32,
add as f32,
)
@ -222,7 +224,7 @@ impl BackendStorage for MetalStorage {
&self.buffer,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&mut buffer,
&buffer,
mul as f32,
add as f32,
)
@ -246,8 +248,9 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
assert!(sum_dims.len() == 1);
assert!(sum_dims[0] == layout.shape().rank() - 1);
assert!(layout.is_contiguous());
assert!(layout.start_offset() == 0);
assert!(layout.stride()[sum_dims[0]] == 1);
// assert!(layout.is_contiguous());
// assert!(layout.start_offset() == 0);
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
@ -282,6 +285,9 @@ impl BackendStorage for MetalStorage {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
if dtype == DType::U32 {
todo!("Implement this");
}
let mut buffer = device.new_buffer(dst_el, dtype);
let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_reduce_contiguous(
@ -292,7 +298,8 @@ impl BackendStorage for MetalStorage {
src_el,
dst_el,
&self.buffer,
&mut buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -327,7 +334,7 @@ impl BackendStorage for MetalStorage {
kernel_name,
el_count,
&self.buffer,
&mut buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
@ -346,7 +353,7 @@ impl BackendStorage for MetalStorage {
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&mut buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
@ -359,23 +366,13 @@ impl BackendStorage for MetalStorage {
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype);
let metal = self.device.clone();
let mut cloned = buffer.clone();
let inbuffer = self.buffer.clone();
let ldims = layout.dims().to_owned();
let lstride = layout.stride().to_owned();
let loffset = layout.start_offset() * dtype.size_in_bytes();
if layout.is_contiguous() && layout.start_offset() == 0 {
// self.device.queue.exec_async(move || {
let device = metal;
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
@ -413,16 +410,11 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&inbuffer,
&mut cloned,
&self.buffer,
&buffer,
)
.unwrap();
// });
.map_err(MetalError::from)?;
} else {
// self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
@ -458,17 +450,15 @@ impl BackendStorage for MetalStorage {
&command_buffer,
&device.kernels,
kernel_name,
&ldims,
&inbuffer,
&lstride,
loffset,
&mut cloned,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
0,
)
.unwrap();
// });
.map_err(MetalError::from)?;
}
Ok(Self {
buffer,
device: device.clone(),
@ -520,7 +510,7 @@ impl BackendStorage for MetalStorage {
el_count,
&self.buffer,
&rhs.buffer,
&mut buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
@ -549,7 +539,7 @@ impl BackendStorage for MetalStorage {
&rhs.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&mut buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
@ -590,7 +580,7 @@ impl BackendStorage for MetalStorage {
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
&mut buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self {
@ -700,7 +690,7 @@ impl BackendStorage for MetalStorage {
dim,
&self.buffer,
&ids.buffer,
&mut buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self {
@ -775,26 +765,26 @@ impl BackendStorage for MetalStorage {
mnk: (m, n, k),
})?
};
let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
} as u64;
let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
} as u64;
// let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
// [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
// [stride] => stride,
// [] => m * k,
// _ => Err(MetalError::MatMulNonContiguous {
// lhs_stride: lhs_stride.to_vec(),
// rhs_stride: rhs_stride.to_vec(),
// mnk: (m, n, k),
// })?,
// } as u64;
// let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
// [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
// [stride] => stride,
// [] => n * k,
// _ => Err(MetalError::MatMulNonContiguous {
// lhs_stride: lhs_stride.to_vec(),
// rhs_stride: rhs_stride.to_vec(),
// mnk: (m, n, k),
// })?,
// } as u64;
let b = b as NSUInteger;
let m = m as NSUInteger;
@ -802,73 +792,72 @@ impl BackendStorage for MetalStorage {
let k = k as NSUInteger;
let left_descriptor = if transpose_left {
MatrixDescriptor::init_single(k, m, m * size, type_id)
MatrixDescriptor::init_multiple(k, m, b, m * size, m * k * size, type_id)
} else {
MatrixDescriptor::init_single(m, k, k * size, type_id)
MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id)
};
let right_descriptor = if transpose_right {
MatrixDescriptor::init_single(n, k, k * size, type_id)
MatrixDescriptor::init_multiple(n, k, b, k * size, k * n * size, type_id)
} else {
MatrixDescriptor::init_single(k, n, n * size, type_id)
MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id)
};
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
(bi * stride_left + lhs_l.start_offset() as u64) * size,
&left_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
(bi * stride_right + rhs_l.start_offset() as u64) * size,
&right_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
lhs_l.start_offset() as NSUInteger * size,
&left_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
rhs_l.start_offset() as NSUInteger * size,
&right_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
bi * m * n * size,
&result_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
0,
&result_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
matrix_multiplication.set_batch_size(b);
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
}
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
Ok(Self {
buffer: out_buffer,
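Since the old per-batch loop and the new batched code are interleaved above, here is a condensed restatement of just the new path for the non-transposed case. It reuses the MPS calls shown in this hunk (init_multiple, init_with_buffer_descriptor, MatrixMultiplication::init, set_batch_size), with error handling and the start offsets elided, so treat it as a reading aid rather than standalone code.

// Batched descriptors: rows, columns, number of matrices, bytes per row,
// bytes per matrix, element type.
let left_descriptor = MatrixDescriptor::init_multiple(m, k, b, k * size, k * m * size, type_id);
let right_descriptor = MatrixDescriptor::init_multiple(k, n, b, n * size, n * k * size, type_id);
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);

// One matrix object per operand, covering all `b` batch entries.
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor).unwrap();
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor).unwrap();
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor).unwrap();

// A single kernel now covers the whole batch instead of one encode per `bi`.
let matmul = MatrixMultiplication::init(&self.device, false, false, m, n, k, 1.0, 0.0).unwrap();
matmul.set_batch_size(b);
matmul.encode_to_command_buffer(&command_buffer, &left_matrix, &right_matrix, &result_matrix);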
@ -900,7 +889,7 @@ impl BackendStorage for MetalStorage {
&self.buffer,
src_l.stride(),
src_l.start_offset() * self.dtype.size_in_bytes(),
&mut dst.buffer,
&dst.buffer,
dst_offset * dst.dtype.size_in_bytes(),
)
.map_err(MetalError::from)?;
@ -909,7 +898,7 @@ impl BackendStorage for MetalStorage {
}
impl MetalStorage {
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
Self {
buffer,
device,
@ -944,14 +933,12 @@ impl BackendDevice for MetalDevice {
let command_queue = device.new_command_queue();
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new());
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
let buffers = Arc::new(RwLock::new(FxHashMap::default()));
let buffers = Arc::new(RwLock::new(HashMap::new()));
Ok(Self {
device,
command_queue,
command_buffer,
buffers,
queue,
kernels,
})
}
@ -996,7 +983,7 @@ impl BackendDevice for MetalDevice {
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
};
Ok(Self::Storage {
buffer,
buffer: buffer.into(),
device: self.clone(),
dtype: storage.dtype(),
})

View File

@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
[[example]]
name = "llama_multiprocess"

View File

@ -10,8 +10,7 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -0,0 +1,499 @@
//
// GEMM.metal
// MetalFlashAttention
//
// Created by Philip Turner on 6/23/23.
//
#include <metal_stdlib>
#include "metal_data_type"
#include "metal_simdgroup_event"
#include "metal_simdgroup_matrix_storage"
using namespace metal;
// MARK: - Function Constants
// Dimensions of each matrix.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];
// Whether each matrix is transposed.
constant bool A_trans [[function_constant(10)]];
constant bool B_trans [[function_constant(11)]];
constant bool D_trans [[function_constant(13)]];
constant uint A_leading_dim = A_trans ? M : K;
constant uint B_leading_dim = B_trans ? K : N;
// Alpha and beta constants from BLAS.
constant float alpha [[function_constant(20)]];
constant float beta [[function_constant(21)]];
constant bool batched [[function_constant(100)]];
constant bool fused_activation [[function_constant(101)]];
constant bool fused_bias [[function_constant(50001)]]; // 102
constant bool use_bias = is_function_constant_defined(fused_bias) ? fused_bias : false;
constant bool use_activation_function = fused_activation && !fused_bias;
constant bool use_activation = use_bias || use_activation_function;
constant bool batched_activation_function = batched && use_activation_function;
constant ushort M_simd [[function_constant(200)]];
constant ushort N_simd [[function_constant(201)]];
constant ushort K_simd [[function_constant(202)]];
// Elide work on the edge when matrix dimension < SRAM block dimension.
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
constant ushort M_splits [[function_constant(210)]];
constant ushort N_splits [[function_constant(211)]];
constant ushort M_group = M_simd * M_splits;
constant ushort N_group = N_simd * N_splits;
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);
// There is no padding for M reads/writes.
// There is no padding for N reads/writes.
constant ushort K_simd_unpadded = (K % K_simd == 0) ? K_simd : (K % K_simd);
constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;
constant ushort A_sram_length = (M_simd / 8) * 1;
constant ushort B_sram_length = 1 * (N_simd / 8);
constant ushort A_block_length = M_group * K_simd;
// Threadgroup block must fit entire C accumulator and partial sums.
constant ushort A_sram_offset = 0;
constant ushort B_sram_offset = A_sram_offset + A_sram_length;
constant ushort C_sram_offset = B_sram_offset + B_sram_length;
constant ushort A_block_offset = 0;
constant ushort B_block_offset = A_block_offset + A_block_length;
// MARK: - Utilities
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
// A_sram[M_simd][8]
return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
// A_sram[8][N_simd]
return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
// C_sram[M_simd][N_simd]
return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC void prefetch(threadgroup T *A_block, device T *A,
ushort2 A_tile_src, uint2 A_offset,
threadgroup T *B_block, device T *B,
ushort2 B_tile_src, uint2 B_offset, uint k)
{
A_tile_src.x = min(uint(K_simd), K - k);
B_tile_src.y = min(uint(K_simd), K - k);
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
auto B_src = simdgroup_matrix_storage<T>::apply_offset(B, B_leading_dim, B_offset, B_trans);
// Rounded-up ceiling for the threadgroup block.
const uint K_edge_floor = K - K_simd_unpadded;
const uint K_edge_ceil = K_edge_floor + K_simd_padded;
ushort K_padded;
if (K_edge_floor == K_simd) {
K_padded = K_simd;
} else {
K_padded = min(uint(K_simd), K_edge_ceil - k);
}
ushort2 A_tile_dst(K_padded, A_tile_src.y);
ushort2 B_tile_dst(B_tile_src.x, K_padded);
simdgroup_event events[2];
events[0].async_copy(A_block, A_block_leading_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
events[1].async_copy(B_block, B_block_leading_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
simdgroup_event::wait(2, events);
}
// One iteration of the MACC loop, effectively k=8 iterations.
template <typename T>
METAL_FUNC void multiply_accumulate(thread simdgroup_matrix_storage<T> *sram,
const threadgroup T *A_block,
const threadgroup T *B_block,
bool accumulate = true)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
ushort2 origin(0, m);
A_sram(sram, origin)->load(A_block, A_block_leading_dim, origin, A_trans);
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, 0);
B_sram(sram, origin)->load(B_block, B_block_leading_dim, origin, B_trans);
}
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
auto A = A_sram(sram, ushort2(0, m));
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto B = B_sram(sram, ushort2(n, 0));
auto C = C_sram(sram, ushort2(n, m));
C->multiply(*A, *B, accumulate);
}
}
}
template <typename T>
METAL_FUNC void partial_store(thread simdgroup_matrix_storage<T> *sram,
threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
if (is_k_summation) {
C_sram(sram, origin)->store(C_block, N_simd, origin);
} else {
C_sram(sram, origin)->store(C_block, N_group, origin);
}
}
}
}
template <typename T>
METAL_FUNC void partial_accumulate(thread simdgroup_matrix_storage<T> *sram,
threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
auto B = B_sram(sram, ushort2(n, 0));
if (is_k_summation) {
B->load(C_block, N_simd, origin);
} else {
B->load(C_block, N_group, origin);
}
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
auto B = B_sram(sram, ushort2(n, 0));
auto C = C_sram(sram, origin);
if (is_k_summation) {
C->thread_elements()[0] += B->thread_elements()[0];
} else {
float2 C_old = float2(B->thread_elements()[0]);
float2 C_new = float2(C->thread_elements()[0]);
C->thread_elements()[0] = vec<T, 2>(fast::fma(C_old, beta, C_new));
}
}
}
}
template <typename T>
METAL_FUNC void async_access_accumulator(threadgroup T *C_block, device T *C,
uint2 C_offset, bool is_store)
{
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
min(uint(M_group), M - C_offset.y));
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, C_offset);
simdgroup_event event;
if (is_store) {
event.async_copy(C_src, N, C_tile, C_block, N_group, C_tile);
} else {
event.async_copy(C_block, N_group, C_tile, C_src, N, C_tile);
simdgroup_event::wait(1, &event);
}
}
template <typename T>
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
device T *C, bool m_is_edge, bool n_is_edge)
{
const ushort m_start = (m_is_edge) ? M_modulo : 0;
const ushort n_start = (n_is_edge) ? N_modulo : 0;
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
#pragma clang loop unroll(full)
for (ushort m = m_start; m < m_end; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = n_start; n < n_end; n += 8) {
ushort2 origin(n, m);
C_sram(sram, origin)->store(C, N, origin);
}
}
}
template <typename T>
struct activation_functor {
using function = void(threadgroup T *C,
device void *D,
uint grid_index_in_batch,
uint2 matrix_origin,
ushort2 tile_dimensions,
ushort lane_id);
typedef visible_function_table<function> function_table;
};
// MARK: - Kernels
template <typename T>
void _gemm_impl(device T *A [[buffer(0)]],
device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
device void *D [[buffer(3), function_constant(use_activation)]],
threadgroup T *threadgroup_block [[threadgroup(0)]],
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
typename activation_functor<T>::function_table table [[buffer(11), function_constant(use_activation_function)]],
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
if (batched) {
// TODO: Re-compute every inner loop iteration for FP64 accumulate.
ulong3 offsets = matrix_offsets[gid.z].xyz;
A = (device T*)((device uchar*)A + offsets[0]);
B = (device T*)((device uchar*)B + offsets[1]);
C = (device T*)((device uchar*)C + offsets[2]);
}
simdgroup_matrix_storage<T> sram[1024];
auto A_block = threadgroup_block + A_block_offset;
auto B_block = threadgroup_block + B_block_offset;
ushort2 sid(sidx % N_splits, sidx / N_splits);
ushort2 offset_in_simd = simdgroup_matrix_storage<T>::offset(lane_id);
uint2 A_offset(0, gid.y * M_group);
uint2 B_offset(gid.x * N_group, 0);
{
uint C_base_offset_x = B_offset.x + sid.x * N_simd;
uint C_base_offset_y = A_offset.y + sid.y * M_simd;
if (C_base_offset_x >= N || C_base_offset_y >= M) {
return;
}
}
ushort2 offset_in_group(sid.x * N_simd + offset_in_simd.x,
sid.y * M_simd + offset_in_simd.y);
if (use_bias) {
if (sidx == 0) {
auto bias = (device T*)D;
if (batched) {
ulong offset = matrix_offsets[gid.z].w;
bias = (device T*)((device uchar*)bias + offset);
}
ushort bias_elements;
if (is_function_constant_defined(D_trans) && D_trans) {
bias += A_offset.y;
bias_elements = min(uint(M_group), M - A_offset.y);
} else {
bias += B_offset.x;
bias_elements = min(uint(N_group), N - B_offset.x);
}
simdgroup_event event;
event.async_copy(threadgroup_block, bias, bias_elements);
simdgroup_event::wait(1, &event);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (is_function_constant_defined(D_trans) && D_trans) {
auto bias = threadgroup_block + offset_in_group.y;
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
auto D = bias[m];
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto C = C_sram(sram, ushort2(n, m));
*(C->thread_elements()) = D;
}
}
} else {
auto bias = threadgroup_block + offset_in_group.x;
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto D = *(threadgroup vec<T, 2>*)(bias + n);
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
auto C = C_sram(sram, ushort2(n, m));
*(C->thread_elements()) = D;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
ushort2 A_tile_src;
ushort2 B_tile_src;
if (sidx == 0) {
A_tile_src.y = min(uint(M_group), M - A_offset.y);
B_tile_src.x = min(uint(N_group), N - B_offset.x);
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, 0);
}
if (K > K_simd && !use_bias) {
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
*C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<T>(0);
}
}
}
for (uint K_floor = 0; K_floor < K; K_floor += K_simd) {
ushort2 A_block_offset(offset_in_simd.x, offset_in_group.y);
ushort2 B_block_offset(offset_in_group.x, offset_in_simd.y);
auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(A_block, A_block_leading_dim, A_block_offset, A_trans);
auto B_block_src = simdgroup_matrix_storage<T>::apply_offset(B_block, B_block_leading_dim, B_block_offset, B_trans);
threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma clang loop unroll(full)
for (ushort k = 0; k < K_simd_padded; k += 8) {
bool accumulate = use_bias || !(K <= K_simd && k == 0);
multiply_accumulate(sram, A_block_src, B_block_src, accumulate);
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}
if (K_floor + K_simd < K) {
#pragma clang loop unroll(full)
for (ushort k = K_simd_padded; k < K_simd; k += 8) {
multiply_accumulate(sram, A_block_src, B_block_src);
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sidx == 0) {
uint K_next = K_floor + K_simd;
A_offset.x = K_next;
B_offset.y = K_next;
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, K_next);
}
}
}
if (alpha != 1) {
#pragma clang loop unroll(full)
for (int m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (int n = 0; n < N_padded; n += 8) {
C_sram(sram, ushort2(n, m))->thread_elements()[0] *= alpha;
}
}
}
uint2 C_offset(B_offset.x, A_offset.y);
ushort2 C_block_offset = offset_in_group.xy;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (beta != 0) {
if (sidx == 0) {
async_access_accumulator(threadgroup_block, C, C_offset, false);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
partial_accumulate(sram, C_block, false);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (use_activation_function) {
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
partial_store(sram, C_block, false);
simdgroup_barrier(mem_flags::mem_threadgroup);
uint grid_index_in_batch = (batched ? gid.z : 0);
uint2 matrix_origin = C_offset + uint2(C_block_offset);
matrix_origin &= ~7;
ushort2 tile_dimensions(min(uint(N_group), N - matrix_origin.x),
min(uint(M_group), M - matrix_origin.y));
uint function_index = 0;
if (batched_activation_function) {
function_index = activation_function_offsets[gid.z];
}
table[function_index](C_block, D, grid_index_in_batch, matrix_origin, tile_dimensions, lane_id);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sidx == 0) {
async_access_accumulator(threadgroup_block, C, C_offset, true);
}
return;
} else if ((M % 8 != 0) || (N % 8 != 0)) {
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
partial_store(sram, C_block, false);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sidx == 0) {
async_access_accumulator(threadgroup_block, C, C_offset, true);
}
} else {
uint2 matrix_origin = C_offset + uint2(C_block_offset);
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, matrix_origin);
store_accumulator(sram, C_src, false, false);
const uint M_edge_floor = M - M % M_simd;
const uint N_edge_floor = N - N % N_simd;
if (matrix_origin.y < M_edge_floor) {
store_accumulator(sram, C_src, true, false);
}
if (matrix_origin.x < N_edge_floor) {
store_accumulator(sram, C_src, false, true);
if (matrix_origin.y < M_edge_floor) {
store_accumulator(sram, C_src, true, true);
}
}
}
}
kernel void hgemm(device half *A [[buffer(0)]],
device half *B [[buffer(1)]],
device half *C [[buffer(2)]],
device void *D [[buffer(3), function_constant(use_activation)]],
threadgroup half *threadgroup_block [[threadgroup(0)]],
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
typename activation_functor<half>::function_table table [[buffer(11), function_constant(use_activation_function)]],
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
_gemm_impl<half>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
}
kernel void sgemm(device float *A [[buffer(0)]],
device float *B [[buffer(1)]],
device float *C [[buffer(2)]],
device void *D [[buffer(3), function_constant(use_activation)]],
threadgroup float *threadgroup_block [[threadgroup(0)]],
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
typename activation_functor<float>::function_table table [[buffer(11), function_constant(use_activation_function)]],
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
_gemm_impl<float>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
}
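As a concrete sizing example (the block parameters here are illustrative; the real values come from the host-side GEMM configuration): with M_splits = N_splits = 2 and M_simd = N_simd = 32, the kernel's constants give M_group = N_group = 64, and the dispatch in `call_gemm` below uses threadgroups of 32 * M_splits * N_splits = 128 threads on a grid of ceil(N / N_group) × ceil(M / M_group) threadgroups. For M = N = 1024 that is 16 × 16 = 256 threadgroups of 128 threads, each owning a 64 × 64 tile of C.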

View File

@ -14,6 +14,7 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const FLASH: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
@ -106,6 +107,7 @@ pub enum Source {
Ternary,
Cast,
Reduce,
Gemm,
}
macro_rules! ops{
@ -229,6 +231,7 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Gemm => ""
}
}
@ -241,10 +244,17 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let source_content = self.get_library_source(source);
let lib = device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
let lib = match source {
Source::Gemm => device
.new_library_with_data(FLASH)
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
_source => {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
}
};
libraries.insert(source, lib.clone());
Ok(lib)
}
@ -291,6 +301,160 @@ impl Kernels {
}
}
enum Gemm {
Float,
Half,
}
impl Gemm {
fn size_of_dtype(&self) -> usize {
match self {
Gemm::Float => 4,
Gemm::Half => 2,
}
}
fn name(&self) -> &'static str {
match self {
Gemm::Float => "sgemm",
Gemm::Half => "hgemm",
}
}
}
pub fn call_gemm(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: Gemm,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Gemm, name.name())?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let config = gemm_config(&p);
let m_group = config.m_group;
let n_group = config.n_group;
let k_simd = config.k_simd.value;
let m_splits = config.m_splits.value;
let n_splits = config.n_splits.value;
let size_of_dtype = name.size_of_dtype();
let a_block_bytes = m_group * k_simd * size_of_dtype;
let b_block_bytes = k_simd * n_group * size_of_dtype;
let c_block_bytes = m_group * n_group * size_of_dtype;
let mut thread_group_memory_length = a_block_bytes + b_block_bytes;
if p.m % 8 > 0 && p.n % 8 > 0 {
thread_group_memory_length = max(thread_group_memory_length, c_block_bytes);
}
if p.fused_bias {
let d_block_bytes = if p.d_trans {
m_group * T::SIZE
} else {
n_group * T::SIZE
};
thread_group_memory_length = max(thread_group_memory_length, d_block_bytes);
}
let grid_size = MTLSize::new(
utils::ceil_divide(p.n, n_group)?,
utils::ceil_divide(p.m, m_group)?,
1,
);
let group_size = MTLSize::new((32 * m_splits * n_splits) as NSUInteger, 1, 1);
let mut flags = 0;
if p.batched {
flags |= 0x1;
}
if p.fused_activation {
flags |= 0x2;
}
if p.fused_bias {
flags |= 0x4;
}
let constant_values = config.create_function_constant_values();
let function = lib.get_function(T::FN_NAME, Some(constant_values))?;
encoder.set_threadgroup_memory_length(0, thread_group_memory_length);
encoder.use_resources(&[a.buffer(), b.buffer()], MTLResourceUsage::Read);
encoder.use_resource(c.buffer(), MTLResourceUsage::Write);
if let Some(d) = d {
encoder.use_resource(d.buffer(), MTLResourceUsage::Read);
}
encoder.set_buffers(
0,
&[Some(a.buffer()), Some(b.buffer()), Some(c.buffer())],
&[0; 3],
);
if let Some(d) = d {
encoder.set_buffer(3, Some(d.buffer()), 0);
}
let mut grid_z = 1;
if pipeline.flags() & 0x1 > 0 {
panic!("Batched gemm not implemented yet");
// let batch_dimensions_a = tensors.a.shape.dropLast(2);
// let batch_dimensions_b = tensors.b.shape.dropLast(2);
// let batch_dimensions_c = tensors.c.shape.dropLast(2);
// assert!(batch_dimensions_a.iter().product() > 0);
// assert!(
// batch_dimensions_b.iter().product() == 1 ||
// batch_dimensions_b == batch_dimensions_a);
// assert!(batch_dimensions_a == batch_dimensions_c);
// grid_z = batch_dimensions_a.iter().product();
//
// if let Some(batch_dimensions_d) = tensors.d { .shape.dropLast(1)
// assert!(
// batch_dimensions_d.reduce(1, *) == 1 ||
// batch_dimensions_d == batch_dimensions_a);
// }
//
// // Mixed precision will cause undefined behavior.
// let element_size = mem::size_of::<T>();
// let byte_stride = |shape: Vec<u64>| -> u32 {
// let rank = shape.len();
// let mut output = element_size * shape[rank - 2] * shape[rank - 1];
// if shape.dropLast(2).product() == 1 {
// output = 0
// }
// output
// } as u32;
// let byte_stride_a = byte_stride(tensors.a.shape);
// let byte_stride_b = byte_stride(tensors.b.shape);
// let byte_stride_c = byte_stride(tensors.c.shape);
//
// var byteStrideD = 0
// if let shapeD = tensors.d?.shape {
// let rank = shapeD.count
// byteStrideD = element_size * shapeD[rank - 1]
// if shapeD.dropLast(1).reduce(1, *) == 1 {
// byteStrideD = 0
// }
// }
// withUnsafeTemporaryAllocation(
// of: SIMD4<UInt64>.self, capacity: gridZ
// ) { buffer in
// for i in 0..<buffer.count {
// buffer[i] = SIMD4(
// UInt64(truncatingIfNeeded: i * byte_stride_a),
// UInt64(truncatingIfNeeded: i * byte_stride_b),
// UInt64(truncatingIfNeeded: i * byte_stride_c),
// UInt64(truncatingIfNeeded: i * byteStrideD))
// }
//
// let bufferLength = buffer.count * MemoryLayout<SIMD3<UInt64>>.stride
// assert(MemoryLayout<SIMD3<UInt64>>.stride == 8 * 4)
// encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
// }
Ok(())
}
pub fn call_unary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@ -298,11 +462,8 @@ pub fn call_unary_contiguous(
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
output: &mut Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -323,7 +484,7 @@ pub fn call_unary_strided(
input: &Buffer,
strides: &[usize],
offset: usize,
output: &mut Buffer,
output: &Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
@ -361,7 +522,7 @@ pub fn call_binary_contiguous(
length: usize,
left: &Buffer,
right: &Buffer,
output: &mut Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@ -389,7 +550,7 @@ pub fn call_binary_strided(
right_input: &Buffer,
right_strides: &[usize],
right_offset: usize,
output: &mut Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@ -428,7 +589,7 @@ pub fn call_cast_contiguous(
kernel_name: &'static str,
length: usize,
input: &Buffer,
output: &mut Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@ -453,7 +614,7 @@ pub fn call_cast_strided(
input: &Buffer,
input_strides: &[usize],
input_offset: usize,
output: &mut Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
@ -484,7 +645,8 @@ pub fn call_reduce_contiguous(
length: usize,
out_length: usize,
input: &Buffer,
output: &mut Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
@ -492,7 +654,10 @@ pub fn call_reduce_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, input, output));
set_params!(
encoder,
(length, elements_to_sum, (input, input_offset), output)
);
let thread_group_count = MTLSize {
width: out_length as u64,
@ -525,7 +690,7 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
output: &mut Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
@ -566,7 +731,7 @@ pub fn call_affine(
name: &'static str,
size: usize,
input: &Buffer,
output: &mut Buffer,
output: &Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
@ -592,7 +757,7 @@ pub fn call_affine_strided(
input: &Buffer,
input_stride: &[usize],
input_offset: usize,
output: &mut Buffer,
output: &Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
@ -634,7 +799,7 @@ pub fn call_where_cond_strided(
(left_stride, left_offset): (&[usize], usize),
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
output: &mut Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@ -677,7 +842,7 @@ pub fn call_index_select(
dim: usize,
input: &Buffer,
ids: &Buffer,
output: &mut Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@ -744,7 +909,7 @@ mod tests {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
command_buffer,
@ -752,7 +917,7 @@ mod tests {
name,
v.len(),
&input,
&mut output,
&output,
)
.unwrap();
command_buffer.commit();
@ -768,7 +933,7 @@ mod tests {
let options = MTLResourceOptions::StorageModeManaged;
let left = new_buffer(&device, x);
let right = new_buffer(&device, y);
let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
call_binary_contiguous(
&device,
command_buffer,
@ -777,7 +942,7 @@ mod tests {
x.len(),
&left,
&right,
&mut output,
&output,
)
.unwrap();
command_buffer.commit();
@ -796,7 +961,7 @@ mod tests {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let output = new_buffer(&device, v);
let kernels = Kernels::new();
call_unary_strided(
&device,
@ -807,7 +972,7 @@ mod tests {
&input,
strides,
offset,
&mut output,
&output,
0,
)
.unwrap();
@ -894,7 +1059,7 @@ mod tests {
#[test]
fn cos_strided_random() {
let v: Vec<_> = (0..10_000).map(|i| rand::random::<f32>()).collect();
let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
let shape = vec![5_000, 2];
let strides = vec![1, 5_000];
let offset = 0;
@ -936,7 +1101,7 @@ mod tests {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_cast_contiguous(
&device,
@ -945,7 +1110,7 @@ mod tests {
name,
v.len(),
&input,
&mut output,
&output,
)
.unwrap();
command_buffer.commit();
@ -975,7 +1140,7 @@ mod tests {
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let output = new_buffer(&device, v);
let size = v.len();
@ -986,7 +1151,7 @@ mod tests {
"affine_float",
size,
&input,
&mut output,
&output,
mul as f32,
add as f32,
)
@ -997,7 +1162,7 @@ mod tests {
output.read_to_vec::<T>(v.len())
}
fn run_affine_strided<T: Clone>(
fn _run_affine_strided<T: Clone>(
v: &[T],
shape: &[usize],
strides: &[usize],
@ -1010,9 +1175,7 @@ mod tests {
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let size = v.len();
let output = new_buffer(&device, v);
call_affine_strided(
&device,
@ -1023,7 +1186,7 @@ mod tests {
&input,
strides,
0,
&mut output,
&output,
mul as f32,
add as f32,
)
@ -1108,7 +1271,7 @@ mod tests {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let kernels = Kernels::new();
call_index_select(
@ -1121,7 +1284,7 @@ mod tests {
dim,
&embeddings_buffer,
&ids_buffer,
&mut dst_buffer,
&dst_buffer,
)
.unwrap();
@ -1218,7 +1381,7 @@ mod tests {
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let mut output =
let output =
device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
call_reduce_contiguous(
&device,
@ -1228,7 +1391,8 @@ mod tests {
v.len(),
out_length,
&input,
&mut output,
0,
&output,
)
.unwrap();
command_buffer.commit();
@ -1247,7 +1411,7 @@ mod tests {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let mut output = new_buffer(&device, v);
let output = new_buffer(&device, v);
call_last_softmax(
&device,
command_buffer,
@ -1256,7 +1420,7 @@ mod tests {
v.len(),
last_dim,
&input,
&mut output,
&output,
)
.unwrap();
command_buffer.commit();
@ -1343,7 +1507,7 @@ mod tests {
options,
);
let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_where_cond_strided(
&device,
command_buffer,
@ -1356,7 +1520,7 @@ mod tests {
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
&mut output,
&output,
)
.unwrap();
command_buffer.commit();
@ -1386,4 +1550,50 @@ mod tests {
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
#[test]
fn flash_gemm() {
let b = 2;
let m = 3;
let n = 2;
let k = 4;
let left: Vec<_> = (0..b * m * k).map(|f| f as f32).collect();
let right: Vec<_> = (0..b * k * n).map(|f| f as f32).collect();
let out: Vec<_> = (0..b * m * n).map(|f| f as f32).collect();
let dims = 3;
let left_shape = vec![b, m, k];
let right_shape = vec![b, k, n];
let out_shape = vec![b, m, n];
let left_stride = vec![m * k, k, 1];
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let left = device.new_buffer_with_data(
left.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(left.as_slice()) as u64,
options,
);
let right = device.new_buffer_with_data(
right.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(right.as_slice()) as u64,
options,
);
let out = device.new_buffer(
(out.len() * std::mem::size_of::<f32>()) as NSUInteger,
options,
);
command_buffer.commit();
command_buffer.wait_until_completed();
let results = out.read_to_vec::<f32>(b * m * n);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
}

View File

@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
constant int THREADGROUP_SIZE = 256;
constant int THREADGROUP_SIZE = 1024;
# define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \

View File

@ -19,6 +19,7 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -29,3 +30,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@ -201,6 +201,37 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = storage.device();
let command_buffer = device.command_buffer();
let kernels = device.kernels();
let name = "softmax_float";
assert!(layout.is_contiguous());
assert!(layout.start_offset() == 0);
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let mut output = device.new_buffer(elem_count, storage.dtype());
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
&kernels,
name,
elem_count,
last_dim,
storage.buffer(),
&mut output,
)
.unwrap();
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
}
}
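To close the loop, a hedged usage sketch of the new Metal path from user code, exercising the `metal_fwd` hook above through the `softmax_last_dim` wrapper defined just below. The `Device::new_metal` constructor name is an assumption here (mirroring `Device::new_cuda`), and the `metal` features added in the Cargo.toml hunks above must be enabled.

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // Assumption: a Metal device constructor analogous to `Device::new_cuda`.
    let device = Device::new_metal(0)?;
    let xs = Tensor::randn(0f32, 1f32, (4, 32), &device)?;
    // Dispatches the `softmax_float` kernel through the `metal_fwd` hook above.
    let ys = candle_nn::ops::softmax_last_dim(&xs)?;
    println!("{ys}");
    Ok(())
}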
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {