diff --git a/candle-metal-kernels/src/GEMM.metal b/candle-metal-kernels/src/GEMM.metal new file mode 100644 index 00000000..9504a191 --- /dev/null +++ b/candle-metal-kernels/src/GEMM.metal @@ -0,0 +1,499 @@ +// +// GEMM.metal +// MetalFlashAttention +// +// Created by Philip Turner on 6/23/23. +// + +#include +#include "metal_data_type" +#include "metal_simdgroup_event" +#include "metal_simdgroup_matrix_storage" +using namespace metal; + +// MARK: - Function Constants + +// Dimensions of each matrix. +constant uint M [[function_constant(0)]]; +constant uint N [[function_constant(1)]]; +constant uint K [[function_constant(2)]]; + +// Whether each matrix is transposed. +constant bool A_trans [[function_constant(10)]]; +constant bool B_trans [[function_constant(11)]]; +constant bool D_trans [[function_constant(13)]]; +constant uint A_leading_dim = A_trans ? M : K; +constant uint B_leading_dim = B_trans ? K : N; + +// Alpha and beta constants from BLAS. +constant float alpha [[function_constant(20)]]; +constant float beta [[function_constant(21)]]; + +constant bool batched [[function_constant(100)]]; +constant bool fused_activation [[function_constant(101)]]; +constant bool fused_bias [[function_constant(50001)]]; // 102 +constant bool use_bias = is_function_constant_defined(fused_bias) ? fused_bias : false; +constant bool use_activation_function = fused_activation && !fused_bias; +constant bool use_activation = use_bias || use_activation_function; +constant bool batched_activation_function = batched && use_activation_function; + +constant ushort M_simd [[function_constant(200)]]; +constant ushort N_simd [[function_constant(201)]]; +constant ushort K_simd [[function_constant(202)]]; + +// Elide work on the edge when matrix dimension < SRAM block dimension. +constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd); +constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd); +constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd; +constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd; + +constant ushort M_splits [[function_constant(210)]]; +constant ushort N_splits [[function_constant(211)]]; + +constant ushort M_group = M_simd * M_splits; +constant ushort N_group = N_simd * N_splits; +constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd); +constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group); + +// There is no padding for M reads/writes. +// There is no padding for N reads/writes. +constant ushort K_simd_unpadded = (K % K_simd == 0) ? K_simd : (K % K_simd); +constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8; + +constant ushort A_sram_length = (M_simd / 8) * 1; +constant ushort B_sram_length = 1 * (N_simd / 8); +constant ushort A_block_length = M_group * K_simd; + +// Threadgroup block must fit entire C accumulator and partial sums. 
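To make the edge handling above easier to verify, here is a small standalone Rust sketch (ours, not part of the patch) of the `M_modulo` / `M_padded` arithmetic: the ragged edge tile keeps its true size, and is only rounded up to the 8-element simdgroup granularity when the whole matrix is smaller than one SRAM block.

```rust
// Host-side mirror of the kernel's edge-padding constants (illustrative only; in the
// kernel these are baked in as Metal function constants at pipeline creation time).
fn edge_block(dim: u32, block: u16) -> (u16, u16) {
    // `modulo`: size of the ragged edge tile, or a full block when `dim` divides evenly.
    let modulo = if dim % block as u32 == 0 {
        block
    } else {
        (dim % block as u32) as u16
    };
    // `padded`: per-simdgroup loop bound, rounded up to a multiple of 8 only when the
    // whole dimension fits inside a single block.
    let padded = if dim < block as u32 {
        (modulo + 7) / 8 * 8
    } else {
        block
    };
    (modulo, padded)
}

fn main() {
    // M = 1000 with a 48-row block: the edge tile is 40 rows, full blocks elsewhere.
    assert_eq!(edge_block(1000, 48), (40, 48));
    // M = 3 (smaller than one block): the 3-row tile is padded up to 8 for the MACC loop.
    assert_eq!(edge_block(3, 48), (3, 8));
}
```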
+constant ushort A_sram_offset = 0; +constant ushort B_sram_offset = A_sram_offset + A_sram_length; +constant ushort C_sram_offset = B_sram_offset + B_sram_length; +constant ushort A_block_offset = 0; +constant ushort B_block_offset = A_block_offset + A_block_length; + +// MARK: - Utilities + +template +METAL_FUNC thread simdgroup_matrix_storage* A_sram(thread simdgroup_matrix_storage *sram, ushort2 matrix_origin) { + // A_sram[M_simd][8] + return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8); +} + +template +METAL_FUNC thread simdgroup_matrix_storage* B_sram(thread simdgroup_matrix_storage *sram, ushort2 matrix_origin) { + // A_sram[8][N_simd] + return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8); +} + +template +METAL_FUNC thread simdgroup_matrix_storage* C_sram(thread simdgroup_matrix_storage *sram, ushort2 matrix_origin) { + // C_sram[M_simd][N_simd] + return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8); +} + +template +METAL_FUNC void prefetch(threadgroup T *A_block, device T *A, + ushort2 A_tile_src, uint2 A_offset, + threadgroup T *B_block, device T *B, + ushort2 B_tile_src, uint2 B_offset, uint k) +{ + A_tile_src.x = min(uint(K_simd), K - k); + B_tile_src.y = min(uint(K_simd), K - k); + auto A_src = simdgroup_matrix_storage::apply_offset(A, A_leading_dim, A_offset, A_trans); + auto B_src = simdgroup_matrix_storage::apply_offset(B, B_leading_dim, B_offset, B_trans); + + // Rounded-up ceiling for the threadgroup block. + const uint K_edge_floor = K - K_simd_unpadded; + const uint K_edge_ceil = K_edge_floor + K_simd_padded; + ushort K_padded; + if (K_edge_floor == K_simd) { + K_padded = K_simd; + } else { + K_padded = min(uint(K_simd), K_edge_ceil - k); + } + ushort2 A_tile_dst(K_padded, A_tile_src.y); + ushort2 B_tile_dst(B_tile_src.x, K_padded); + + simdgroup_event events[2]; + events[0].async_copy(A_block, A_block_leading_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans); + events[1].async_copy(B_block, B_block_leading_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans); + simdgroup_event::wait(2, events); +} + +// One iteration of the MACC loop, effectively k=8 iterations. 
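For readability, an illustrative Rust model (ours, names hypothetical) of how `A_sram`, `B_sram` and `C_sram` above carve one flat array of 8x8 `simdgroup_matrix_storage` tiles: A takes the first `M_simd/8` slots, B the next `N_simd/8`, and the C accumulator the remaining `(M_simd/8) * (N_simd/8)`.

```rust
// Flat register-tile layout mirroring A_sram/B_sram/C_sram (indices count 8x8 tiles).
struct SramLayout {
    m_simd: usize,
    n_simd: usize,
}

impl SramLayout {
    // A_sram[M_simd][8]: the kernel always passes x = 0, so only y selects the tile.
    fn a_index(&self, origin_y: usize) -> usize {
        origin_y / 8
    }
    // B_sram[8][N_simd]: starts right after the A tiles; the kernel always passes y = 0.
    fn b_index(&self, origin_x: usize) -> usize {
        self.m_simd / 8 + origin_x / 8
    }
    // C_sram[M_simd][N_simd]: row-major in 8x8 tiles, placed after both A and B.
    fn c_index(&self, origin_x: usize, origin_y: usize) -> usize {
        let c_offset = self.m_simd / 8 + self.n_simd / 8;
        c_offset + (origin_y / 8) * (self.n_simd / 8) + origin_x / 8
    }
}

fn main() {
    let sram = SramLayout { m_simd: 32, n_simd: 32 };
    assert_eq!(sram.a_index(16), 2);
    assert_eq!(sram.b_index(8), 5);
    assert_eq!(sram.c_index(8, 16), 17);
}
```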
+template +METAL_FUNC void multiply_accumulate(thread simdgroup_matrix_storage *sram, + const threadgroup T *A_block, + const threadgroup T *B_block, + bool accumulate = true) +{ +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_padded; m += 8) { + ushort2 origin(0, m); + A_sram(sram, origin)->load(A_block, A_block_leading_dim, origin, A_trans); + } +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_padded; n += 8) { + ushort2 origin(n, 0); + B_sram(sram, origin)->load(B_block, B_block_leading_dim, origin, B_trans); + } +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_padded; m += 8) { + auto A = A_sram(sram, ushort2(0, m)); +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_padded; n += 8) { + auto B = B_sram(sram, ushort2(n, 0)); + auto C = C_sram(sram, ushort2(n, m)); + C->multiply(*A, *B, accumulate); + } + } +} + +template +METAL_FUNC void partial_store(thread simdgroup_matrix_storage *sram, + threadgroup T *C_block, bool is_k_summation) +{ +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_padded; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_padded; n += 8) { + ushort2 origin(n, m); + if (is_k_summation) { + C_sram(sram, origin)->store(C_block, N_simd, origin); + } else { + C_sram(sram, origin)->store(C_block, N_group, origin); + } + } + } +} + +template +METAL_FUNC void partial_accumulate(thread simdgroup_matrix_storage *sram, + threadgroup T *C_block, bool is_k_summation) +{ +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_padded; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_padded; n += 8) { + ushort2 origin(n, m); + auto B = B_sram(sram, ushort2(n, 0)); + if (is_k_summation) { + B->load(C_block, N_simd, origin); + } else { + B->load(C_block, N_group, origin); + } + } +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_padded; n += 8) { + ushort2 origin(n, m); + auto B = B_sram(sram, ushort2(n, 0)); + auto C = C_sram(sram, origin); + if (is_k_summation) { + C->thread_elements()[0] += B->thread_elements()[0]; + } else { + float2 C_old = float2(B->thread_elements()[0]); + float2 C_new = float2(C->thread_elements()[0]); + C->thread_elements()[0] = vec(fast::fma(C_old, beta, C_new)); + } + } + } +} + +template +METAL_FUNC void async_access_accumulator(threadgroup T *C_block, device T *C, + uint2 C_offset, bool is_store) +{ + ushort2 C_tile(min(uint(N_group), N - C_offset.x), + min(uint(M_group), M - C_offset.y)); + auto C_src = simdgroup_matrix_storage::apply_offset(C, N, C_offset); + + simdgroup_event event; + if (is_store) { + event.async_copy(C_src, N, C_tile, C_block, N_group, C_tile); + } else { + event.async_copy(C_block, N_group, C_tile, C_src, N, C_tile); + simdgroup_event::wait(1, &event); + } +} + +template +METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage *sram, + device T *C, bool m_is_edge, bool n_is_edge) +{ + const ushort m_start = (m_is_edge) ? M_modulo : 0; + const ushort n_start = (n_is_edge) ? N_modulo : 0; + const ushort m_end = (m_is_edge) ? M_simd : M_modulo; + const ushort n_end = (n_is_edge) ? 
N_simd : N_modulo; + +#pragma clang loop unroll(full) + for (ushort m = m_start; m < m_end; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = n_start; n < n_end; n += 8) { + ushort2 origin(n, m); + C_sram(sram, origin)->store(C, N, origin); + } + } +} + +template +struct activation_functor { + using function = void(threadgroup T *C, + device void *D, + uint grid_index_in_batch, + uint2 matrix_origin, + ushort2 tile_dimensions, + ushort lane_id); + + typedef visible_function_table function_table; +}; + +// MARK: - Kernels + +template +void _gemm_impl(device T *A [[buffer(0)]], + device T *B [[buffer(1)]], + device T *C [[buffer(2)]], + device void *D [[buffer(3), function_constant(use_activation)]], + + threadgroup T *threadgroup_block [[threadgroup(0)]], + constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]], + typename activation_functor::function_table table [[buffer(11), function_constant(use_activation_function)]], + constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]], + + uint3 gid [[threadgroup_position_in_grid]], + ushort sidx [[simdgroup_index_in_threadgroup]], + ushort lane_id [[thread_index_in_simdgroup]]) +{ + if (batched) { + // TODO: Re-compute every inner loop iteration for FP64 accumulate. + ulong3 offsets = matrix_offsets[gid.z].xyz; + A = (device T*)((device uchar*)A + offsets[0]); + B = (device T*)((device uchar*)B + offsets[1]); + C = (device T*)((device uchar*)C + offsets[2]); + } + + simdgroup_matrix_storage sram[1024]; + auto A_block = threadgroup_block + A_block_offset; + auto B_block = threadgroup_block + B_block_offset; + ushort2 sid(sidx % N_splits, sidx / N_splits); + ushort2 offset_in_simd = simdgroup_matrix_storage::offset(lane_id); + + uint2 A_offset(0, gid.y * M_group); + uint2 B_offset(gid.x * N_group, 0); + { + uint C_base_offset_x = B_offset.x + sid.x * N_simd; + uint C_base_offset_y = A_offset.y + sid.y * M_simd; + if (C_base_offset_x >= N || C_base_offset_y >= M) { + return; + } + } + + ushort2 offset_in_group(sid.x * N_simd + offset_in_simd.x, + sid.y * M_simd + offset_in_simd.y); + + if (use_bias) { + if (sidx == 0) { + auto bias = (device T*)D; + if (batched) { + ulong offset = matrix_offsets[gid.z].w; + bias = (device T*)((device uchar*)bias + offset); + } + + ushort bias_elements; + if (is_function_constant_defined(D_trans) && D_trans) { + bias += A_offset.y; + bias_elements = min(uint(M_group), M - A_offset.y); + } else { + bias += B_offset.x; + bias_elements = min(uint(N_group), N - B_offset.x); + } + + simdgroup_event event; + event.async_copy(threadgroup_block, bias, bias_elements); + simdgroup_event::wait(1, &event); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (is_function_constant_defined(D_trans) && D_trans) { + auto bias = threadgroup_block + offset_in_group.y; +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_padded; m += 8) { + auto D = bias[m]; +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_padded; n += 8) { + auto C = C_sram(sram, ushort2(n, m)); + *(C->thread_elements()) = D; + } + } + } else { + auto bias = threadgroup_block + offset_in_group.x; +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_padded; n += 8) { + auto D = *(threadgroup vec*)(bias + n); +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_padded; m += 8) { + auto C = C_sram(sram, ushort2(n, m)); + *(C->thread_elements()) = D; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + ushort2 A_tile_src; + ushort2 
B_tile_src; + if (sidx == 0) { + A_tile_src.y = min(uint(M_group), M - A_offset.y); + B_tile_src.x = min(uint(N_group), N - B_offset.x); + prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, 0); + } + + if (K > K_simd && !use_bias) { +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_padded; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_padded; n += 8) { + *C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage(0); + } + } + } + + for (uint K_floor = 0; K_floor < K; K_floor += K_simd) { + ushort2 A_block_offset(offset_in_simd.x, offset_in_group.y); + ushort2 B_block_offset(offset_in_group.x, offset_in_simd.y); + auto A_block_src = simdgroup_matrix_storage::apply_offset(A_block, A_block_leading_dim, A_block_offset, A_trans); + auto B_block_src = simdgroup_matrix_storage::apply_offset(B_block, B_block_leading_dim, B_block_offset, B_trans); + threadgroup_barrier(mem_flags::mem_threadgroup); + +#pragma clang loop unroll(full) + for (ushort k = 0; k < K_simd_padded; k += 8) { + bool accumulate = use_bias || !(K <= K_simd && k == 0); + multiply_accumulate(sram, A_block_src, B_block_src, accumulate); + A_block_src += A_trans ? 8 * A_block_leading_dim : 8; + B_block_src += B_trans ? 8 : 8 * B_block_leading_dim; + } + + if (K_floor + K_simd < K) { +#pragma clang loop unroll(full) + for (ushort k = K_simd_padded; k < K_simd; k += 8) { + multiply_accumulate(sram, A_block_src, B_block_src); + A_block_src += A_trans ? 8 * A_block_leading_dim : 8; + B_block_src += B_trans ? 8 : 8 * B_block_leading_dim; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sidx == 0) { + uint K_next = K_floor + K_simd; + A_offset.x = K_next; + B_offset.y = K_next; + prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, K_next); + } + } + } + + if (alpha != 1) { +#pragma clang loop unroll(full) + for (int m = 0; m < M_padded; m += 8) { +#pragma clang loop unroll(full) + for (int n = 0; n < N_padded; n += 8) { + C_sram(sram, ushort2(n, m))->thread_elements()[0] *= alpha; + } + } + } + + uint2 C_offset(B_offset.x, A_offset.y); + ushort2 C_block_offset = offset_in_group.xy; + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (beta != 0) { + if (sidx == 0) { + async_access_accumulator(threadgroup_block, C, C_offset, false); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto C_block = simdgroup_matrix_storage::apply_offset(threadgroup_block, N_group, C_block_offset); + partial_accumulate(sram, C_block, false); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (use_activation_function) { + auto C_block = simdgroup_matrix_storage::apply_offset(threadgroup_block, N_group, C_block_offset); + partial_store(sram, C_block, false); + simdgroup_barrier(mem_flags::mem_threadgroup); + + uint grid_index_in_batch = (batched ? 
gid.z : 0); + uint2 matrix_origin = C_offset + uint2(C_block_offset); + matrix_origin &= ~7; + ushort2 tile_dimensions(min(uint(N_group), N - matrix_origin.x), + min(uint(M_group), M - matrix_origin.y)); + uint function_index = 0; + if (batched_activation_function) { + function_index = activation_function_offsets[gid.z]; + } + table[function_index](C_block, D, grid_index_in_batch, matrix_origin, tile_dimensions, lane_id); + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sidx == 0) { + async_access_accumulator(threadgroup_block, C, C_offset, true); + } + return; + } else if ((M % 8 != 0) || (N % 8 != 0)) { + auto C_block = simdgroup_matrix_storage::apply_offset(threadgroup_block, N_group, C_block_offset); + partial_store(sram, C_block, false); + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sidx == 0) { + async_access_accumulator(threadgroup_block, C, C_offset, true); + } + } else { + uint2 matrix_origin = C_offset + uint2(C_block_offset); + auto C_src = simdgroup_matrix_storage::apply_offset(C, N, matrix_origin); + store_accumulator(sram, C_src, false, false); + + const uint M_edge_floor = M - M % M_simd; + const uint N_edge_floor = N - N % N_simd; + if (matrix_origin.y < M_edge_floor) { + store_accumulator(sram, C_src, true, false); + } + if (matrix_origin.x < N_edge_floor) { + store_accumulator(sram, C_src, false, true); + if (matrix_origin.y < M_edge_floor) { + store_accumulator(sram, C_src, true, true); + } + } + } +} + +kernel void hgemm(device half *A [[buffer(0)]], + device half *B [[buffer(1)]], + device half *C [[buffer(2)]], + device void *D [[buffer(3), function_constant(use_activation)]], + + threadgroup half *threadgroup_block [[threadgroup(0)]], + constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]], + typename activation_functor::function_table table [[buffer(11), function_constant(use_activation_function)]], + constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]], + + uint3 gid [[threadgroup_position_in_grid]], + ushort sidx [[simdgroup_index_in_threadgroup]], + ushort lane_id [[thread_index_in_simdgroup]]) +{ + _gemm_impl(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id); +} + +kernel void sgemm(device float *A [[buffer(0)]], + device float *B [[buffer(1)]], + device float *C [[buffer(2)]], + device void *D [[buffer(3), function_constant(use_activation)]], + + threadgroup float *threadgroup_block [[threadgroup(0)]], + constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]], + typename activation_functor::function_table table [[buffer(11), function_constant(use_activation_function)]], + constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]], + + uint3 gid [[threadgroup_position_in_grid]], + ushort sidx [[simdgroup_index_in_threadgroup]], + ushort lane_id [[thread_index_in_simdgroup]]) +{ + _gemm_impl(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id); +} diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index fcf6930b..b995d1c8 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -14,6 +14,7 @@ const BINARY: &str = include_str!("binary.metal"); const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +const FLASH: &[u8] = 
include_bytes!("libMetalFlashAttention.metallib");
 
 fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
     let size = length as u64;
@@ -106,6 +107,7 @@ pub enum Source {
     Ternary,
     Cast,
     Reduce,
+    Gemm,
 }
 
 macro_rules! ops{
@@ -229,6 +231,7 @@ impl Kernels {
             Source::Indexing => INDEXING,
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
+            Source::Gemm => "", // precompiled metallib, loaded below via new_library_with_data
         }
     }
@@ -241,10 +244,17 @@ impl Kernels {
         if let Some(lib) = libraries.get(&source) {
             Ok(lib.clone())
         } else {
-            let source_content = self.get_library_source(source);
-            let lib = device
-                .new_library_with_source(source_content, &CompileOptions::new())
-                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
+            let lib = match source {
+                Source::Gemm => device
+                    .new_library_with_data(FLASH)
+                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
+                _ => {
+                    let source_content = self.get_library_source(source);
+                    device
+                        .new_library_with_source(source_content, &CompileOptions::new())
+                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                }
+            };
             libraries.insert(source, lib.clone());
             Ok(lib)
         }
@@ -291,6 +301,160 @@ impl Kernels {
     }
 }
 
+enum Gemm {
+    Float,
+    Half,
+}
+
+impl Gemm {
+    fn size_of_dtype(&self) -> usize {
+        match self {
+            Gemm::Float => 4,
+            Gemm::Half => 2,
+        }
+    }
+    fn name(&self) -> &'static str {
+        match self {
+            Gemm::Float => "sgemm",
+            Gemm::Half => "hgemm",
+        }
+    }
+}
+
+pub fn call_gemm(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: Gemm,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Gemm, name.name())?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    // TODO: the problem description `p` (shapes and fused_* flags), the A/B/C/D buffers,
+    // the library handle and `gemm_config` still need to be threaded through this signature.
+    let config = gemm_config(&p);
+    let m_group = config.m_group;
+    let n_group = config.n_group;
+    let k_simd = config.k_simd.value;
+    let m_splits = config.m_splits.value;
+    let n_splits = config.n_splits.value;
+
+    // Threadgroup block must fit the A and B tiles, plus the C tile / bias row when used.
+    let size_of_dtype = name.size_of_dtype();
+    let a_block_bytes = m_group * k_simd * size_of_dtype;
+    let b_block_bytes = k_simd * n_group * size_of_dtype;
+    let c_block_bytes = m_group * n_group * size_of_dtype;
+    let mut thread_group_memory_length = a_block_bytes + b_block_bytes;
+
+    if p.m % 8 > 0 && p.n % 8 > 0 {
+        thread_group_memory_length = thread_group_memory_length.max(c_block_bytes);
+    }
+    if p.fused_bias {
+        let d_block_bytes = if p.d_trans {
+            m_group * size_of_dtype
+        } else {
+            n_group * size_of_dtype
+        };
+        thread_group_memory_length = thread_group_memory_length.max(d_block_bytes);
+    }
+
+    // One threadgroup of 32 * M_splits * N_splits threads per (N_group, M_group) output tile.
+    let grid_size = MTLSize::new(
+        utils::ceil_divide(p.n, n_group)?,
+        utils::ceil_divide(p.m, m_group)?,
+        1,
+    );
+    let group_size = MTLSize::new((32 * m_splits * n_splits) as NSUInteger, 1, 1);
+
+    let mut flags = 0;
+    if p.batched {
+        flags |= 0x1;
+    }
+    if p.fused_activation {
+        flags |= 0x2;
+    }
+    if p.fused_bias {
+        flags |= 0x4;
+    }
+
+    let constant_values = config.create_function_constant_values();
+    let function = lib.get_function(name.name(), Some(constant_values))?;
+    encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
+
+    encoder.use_resources(&[a.buffer(), b.buffer()], MTLResourceUsage::Read);
+    encoder.use_resource(c.buffer(), MTLResourceUsage::Write);
+    if let Some(d) = d {
+        encoder.use_resource(d.buffer(), MTLResourceUsage::Read);
+    }
+
+    encoder.set_buffers(
+        0,
+        &[Some(a.buffer()), Some(b.buffer()), Some(c.buffer())],
+        &[0; 3],
+    );
+    if let Some(d) = d {
+        encoder.set_buffer(3, Some(d.buffer()), 0);
+    }
+
+    let mut grid_z = 1;
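+    // Batched path (flag 0x1): grid_z should become the product of the batch dimensions,
+    // and buffer(10) should receive one ulong4 of per-batch byte strides for A, B, C and
+    // the optional bias, matching `matrix_offsets` in the kernel. The Swift original of
+    // that logic is kept below as a comment until it is ported.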
if pipeline.flags() & 0x1 > 0 { + panic!("Batched gemm not implemented yet"); + // let batch_dimensions_a = tensors.a.shape.dropLast(2); + // let batch_dimensions_b = tensors.b.shape.dropLast(2); + // let batch_dimensions_c = tensors.c.shape.dropLast(2); + // assert!(batch_dimensions_a.iter().product() > 0); + // assert!( + // batch_dimensions_b.iter().product() == 1 || + // batch_dimensions_b == batch_dimensions_a); + // assert!(batch_dimensions_a == batch_dimensions_c); + // grid_z = batch_dimensions_a.iter().product(); + // + // if let Some(batch_dimensions_d) = tensors.d { .shape.dropLast(1) + // assert!( + // batch_dimensions_d.reduce(1, *) == 1 || + // batch_dimensions_d == batch_dimensions_a); + // } + // + // // Mixed precision will cause undefined behavior. + // let element_size = mem::size_of::(); + // let byte_stride = |shape: Vec| -> u32 { + // let rank = shape.len(); + // let mut output = element_size * shape[rank - 2] * shape[rank - 1]; + // if shape.dropLast(2).product() == 1 { + // output = 0 + // } + // output + // } as u32; + // let byte_stride_a = byte_stride(tensors.a.shape); + // let byte_stride_b = byte_stride(tensors.b.shape); + // let byte_stride_c = byte_stride(tensors.c.shape); + // + // var byteStrideD = 0 + // if let shapeD = tensors.d?.shape { + // let rank = shapeD.count + // byteStrideD = element_size * shapeD[rank - 1] + // if shapeD.dropLast(1).reduce(1, *) == 1 { + // byteStrideD = 0 + // } + // } + // withUnsafeTemporaryAllocation( + // of: SIMD4.self, capacity: gridZ + // ) { buffer in + // for i in 0..>.stride + // assert(MemoryLayout>.stride == 8 * 4) + // encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10) + // } + Ok(()) +} + pub fn call_unary_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -481,7 +645,7 @@ pub fn call_reduce_contiguous( length: usize, out_length: usize, input: &Buffer, - input_offset: usize, + input_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; @@ -490,7 +654,10 @@ pub fn call_reduce_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, (input,input_offset), output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let thread_group_count = MTLSize { width: out_length as u64, @@ -742,7 +909,7 @@ mod tests { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_unary_contiguous( &device, command_buffer, @@ -766,7 +933,7 @@ mod tests { let options = MTLResourceOptions::StorageModeManaged; let left = new_buffer(&device, x); let right = new_buffer(&device, y); - let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + let output = device.new_buffer(std::mem::size_of_val(x) as u64, options); call_binary_contiguous( &device, command_buffer, @@ -794,7 +961,7 @@ mod tests { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let kernels = Kernels::new(); call_unary_strided( &device, @@ -892,7 +1059,7 @@ mod tests { #[test] fn cos_strided_random() { - let v: Vec<_> = (0..10_000).map(|i| 
rand::random::()).collect(); + let v: Vec<_> = (0..10_000).map(|_| rand::random::()).collect(); let shape = vec![5_000, 2]; let strides = vec![1, 5_000]; let offset = 0; @@ -934,7 +1101,7 @@ mod tests { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_cast_contiguous( &device, @@ -973,7 +1140,7 @@ mod tests { let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let size = v.len(); @@ -995,7 +1162,7 @@ mod tests { output.read_to_vec::(v.len()) } - fn run_affine_strided( + fn _run_affine_strided( v: &[T], shape: &[usize], strides: &[usize], @@ -1008,9 +1175,7 @@ mod tests { let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - - let size = v.len(); + let output = new_buffer(&device, v); call_affine_strided( &device, @@ -1106,7 +1271,7 @@ mod tests { let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let dst_el = ids.len() * left_size * right_size; - let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); let kernels = Kernels::new(); call_index_select( @@ -1216,7 +1381,7 @@ mod tests { let input = new_buffer(&device, v); let options = MTLResourceOptions::StorageModeManaged; - let mut output = + let output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); call_reduce_contiguous( &device, @@ -1226,7 +1391,7 @@ mod tests { v.len(), out_length, &input, - 0, + 0, &output, ) .unwrap(); @@ -1246,7 +1411,7 @@ mod tests { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_last_softmax( &device, command_buffer, @@ -1342,7 +1507,7 @@ mod tests { options, ); - let mut output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); call_where_cond_strided( &device, command_buffer, @@ -1385,4 +1550,50 @@ mod tests { ); assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } + + #[test] + fn flash_gemm() { + let b = 2; + let m = 3; + let n = 2; + let k = 4; + + + let left: Vec<_> = (0..b*m*k).map(|f| f as f32).collect(); + let right: Vec<_> = (0..b*k*n).map(|f| f as f32).collect(); + let out: Vec<_> = (0..b*m*n).map(|f| f as f32).collect(); + + let dims = 3; + let left_shape= vec![b, m, k]; + let right_shape= vec![b, k, n]; + let out_shape = vec![b, m , n]; + + let left_stride = vec![m * k, k, 1]; + + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let left = device.new_buffer_with_data( + left.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(left.as_slice()) as u64, + options, + ); + let right = device.new_buffer_with_data( + right.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(right.as_slice()) as u64, + options, + ); + let out = device.new_buffer( + (out.len() * 
std::mem::size_of::<f32>()) as NSUInteger,
+            options,
+        );
+
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        let results = out.read_to_vec::<f32>(b * m * n);
+        assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
+    }
 }
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
new file mode 100644
index 00000000..dafd1856
Binary files /dev/null and b/candle-metal-kernels/src/libMetalFlashAttention.metallib differ
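The `flash_gemm` test above does not yet encode the GEMM call before committing the command buffer, and its expected values appear to be carried over from the earlier test. For the shapes it uses (b = 2, m = 3, n = 2, k = 4, row-major, inputs 0..b*m*k and 0..b*k*n), a plain CPU reference such as the sketch below (ours, not part of the patch) gives the values the kernel should reproduce once the dispatch is wired up:

```rust
// Naive batched matmul over contiguous row-major buffers, for checking the Metal kernel.
fn reference_gemm(b: usize, m: usize, n: usize, k: usize, lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    let mut out = vec![0.0f32; b * m * n];
    for bi in 0..b {
        for mi in 0..m {
            for ni in 0..n {
                let mut acc = 0.0f32;
                for ki in 0..k {
                    acc += lhs[bi * m * k + mi * k + ki] * rhs[bi * k * n + ki * n + ni];
                }
                out[bi * m * n + mi * n + ni] = acc;
            }
        }
    }
    out
}

#[test]
fn flash_gemm_reference() {
    let (b, m, n, k) = (2, 3, 2, 4);
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs: Vec<f32> = (0..b * k * n).map(|f| f as f32).collect();
    assert_eq!(
        reference_gemm(b, m, n, k, &lhs, &rhs),
        vec![28., 34., 76., 98., 124., 162., 604., 658., 780., 850., 956., 1042.]
    );
}
```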