diff --git a/candle-metal-kernels/build.rs b/candle-metal-kernels/build.rs index 22f279bd..484c2e68 100644 --- a/candle-metal-kernels/build.rs +++ b/candle-metal-kernels/build.rs @@ -1,7 +1,7 @@ use std::process::Command; use std::{env, str}; -const COMPILED_KERNELS: [&str; 1] = ["reduce"]; +const COMPILED_KERNELS: [&str; 3] = ["event", "matrix_storage", "gemm"]; enum Platform { MacOS, diff --git a/candle-metal-kernels/src/ffi.rs b/candle-metal-kernels/src/ffi.rs new file mode 100644 index 00000000..e69de29b diff --git a/candle-metal-kernels/src/kernels/event.metal b/candle-metal-kernels/src/kernels/event.metal new file mode 100644 index 00000000..93f26718 --- /dev/null +++ b/candle-metal-kernels/src/kernels/event.metal @@ -0,0 +1,226 @@ +// -*- Metal -*- +//===-- metal_simdgroup_event ---------------------------------------------===// +// Copyright (c) 2024 Philip Turner. See MIT LICENSE +//===----------------------------------------------------------------------===// + +#ifndef __METAL_SIMDGROUP_EVENT +#define __METAL_SIMDGROUP_EVENT + +// Invoking the generation of LLVM bitcode for async copies. +// +// %struct._simdgroup_event_t = type opaque +// +struct _simdgroup_event_t; + +// Invoking the generation of LLVM bitcode for async copies. +// +// Bitcode: TBD +// +thread _simdgroup_event_t* +__metal_simdgroup_async_copy_1d( + ulong, ulong, threadgroup void *, const device void *, ulong) + __asm("air.simdgroup_async_copy_1d.p3i8.p1i8"); + +// Invoking the generation of LLVM bitcode for async copies. +// +// Bitcode: TBD +// +thread _simdgroup_event_t* +__metal_simdgroup_async_copy_1d( + ulong, ulong, device void *, const threadgroup void *, ulong) + __asm("air.simdgroup_async_copy_1d.p1i8.p3i8"); + +// Invoking the generation of LLVM bitcode for async copies. +// +// ; Function Attrs: argmemonly convergent nounwind +// declare %struct._simdgroup_event_t* +// @air.simdgroup_async_copy_2d.p3i8.p1i8( +// i64, i64, +// i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>, +// i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>, +// <2 x i64>, i32) +// local_unnamed_addr #4 +// +thread _simdgroup_event_t* +__metal_simdgroup_async_copy_2d( + ulong, ulong, + threadgroup void *, ulong, ulong, ulong2, + const device void *, ulong, ulong, ulong2, + long2, int) + __asm("air.simdgroup_async_copy_2d.p3i8.p1i8"); + +// Invoking the generation of LLVM bitcode for async copies. +// +// ; Function Attrs: argmemonly convergent nounwind +// declare %struct._simdgroup_event_t* +// @air.simdgroup_async_copy_2d.p1i8.p3i8( +// i64, i64, +// i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>, +// i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>, +// <2 x i64>, i32) +// local_unnamed_addr #4 +// +thread _simdgroup_event_t* +__metal_simdgroup_async_copy_2d( + ulong, ulong, + device void *, ulong, ulong, ulong2, + const threadgroup void *, ulong, ulong, ulong2, + long2, int) + __asm("air.simdgroup_async_copy_2d.p1i8.p3i8"); + +// Invoking the generation of LLVM bitcode for async copies. 
+// +// ; Function Attrs: convergent nounwind +// declare void +// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture) +// local_unnamed_addr #3 +// +void __metal_wait_simdgroup_events( + int, thread _simdgroup_event_t**) + __asm("air.wait_simdgroup_events"); + +#pragma METAL internals : enable +namespace metal +{ + enum class simdgroup_async_copy_clamp_mode { + clamp_to_zero = 0, + clamp_to_edge = 1 + }; + + struct simdgroup_event { + METAL_FUNC simdgroup_event() thread {} + + template + METAL_FUNC void async_copy( + threadgroup T *dst, + const device T *src, + ulong n_elements + ) thread { + event = __metal_simdgroup_async_copy_1d( + // Description of the data type. + sizeof(T), + alignof(T), + + // Description of the arguments. + reinterpret_cast(dst), + reinterpret_cast(src), + n_elements); + } + + template + METAL_FUNC void async_copy( + device T *dst, + const threadgroup T *src, + ulong n_elements + ) thread { + event = __metal_simdgroup_async_copy_1d( + // Description of the data type. + sizeof(T), + alignof(T), + + // Description of the arguments. + reinterpret_cast(dst), + reinterpret_cast(src), + n_elements); + } + + template + METAL_FUNC void async_copy( + // Description of the destination. + threadgroup T *dst, + ushort dst_elements_per_row, + ushort2 dst_tile_dimensions, + + // Description of the source. + const device T *src, + uint src_elements_per_row, + ushort2 src_tile_dimensions, + + // Other arguments. + bool transpose_matrix = false, + simdgroup_async_copy_clamp_mode clamp_mode = + simdgroup_async_copy_clamp_mode::clamp_to_zero + ) thread { + if (transpose_matrix) { + src_tile_dimensions = src_tile_dimensions.yx; + dst_tile_dimensions = dst_tile_dimensions.yx; + } + event = __metal_simdgroup_async_copy_2d( + // Description of the data type. + sizeof(T), + alignof(T), + + // Description of the destination. + reinterpret_cast(dst), + ushort(dst_elements_per_row), + 1, + ulong2(dst_tile_dimensions), + + // Description of the source. + reinterpret_cast(src), + uint(src_elements_per_row), + 1, + ulong2(src_tile_dimensions), + + // Other arguments. + long2(0), + static_cast(clamp_mode)); + } + + template + METAL_FUNC void async_copy( + // Description of the destination. + device T *dst, + uint dst_elements_per_row, + ushort2 dst_tile_dimensions, + + // Description of the source. + const threadgroup T *src, + ushort src_elements_per_row, + ushort2 src_tile_dimensions, + + // Other arguments. + bool transpose_matrix = false + ) thread { + if (transpose_matrix) { + src_tile_dimensions = src_tile_dimensions.yx; + dst_tile_dimensions = dst_tile_dimensions.yx; + } + event = __metal_simdgroup_async_copy_2d( + // Description of the data type. + sizeof(T), + alignof(T), + + // Description of the destination. + reinterpret_cast(dst), + uint(dst_elements_per_row), + 1, + ulong2(dst_tile_dimensions), + + // Description of the source. + reinterpret_cast(src), + ushort(src_elements_per_row), + 1, + ulong2(src_tile_dimensions), + + // Other arguments. + long2(0), + 0); + } + + METAL_FUNC static void wait(int count, thread simdgroup_event *events) { + __metal_wait_simdgroup_events( + count, reinterpret_cast(events)); + } + + private: + // Invoking the generation of LLVM bitcode for async copies. 
+ // + // %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* } + // + thread _simdgroup_event_t* event; + }; +} // namespace metal +#pragma METAL internals : disable + +#endif diff --git a/candle-metal-kernels/src/kernels/gemm.metal b/candle-metal-kernels/src/kernels/gemm.metal new file mode 100644 index 00000000..06bf6037 --- /dev/null +++ b/candle-metal-kernels/src/kernels/gemm.metal @@ -0,0 +1,538 @@ +// Heavily inspired by the GEMM kernels by Philip Turner: +// https://github.com/philipturner/metal-flash-attention +// This implementation uses generics and specialization to generate kernels for different data types instead of code generation. +#include +#include "event.metal" +#include "matrix_storage.metal" +using namespace metal; + +// Dimensions of each matrix. +// - Limitations to matrix size: +// - 2^32 in each dimension (M/N/K). +// - TODO: Test whether the maximum dimension with correct execution is +// actually 2^16. This will require a testing setup with non-square +// matrices, as 65536^3 is uncomputable. +// - Extending to 2^64 may require changing 'uint' to 'ulong'. There is a +// good chance this will significantly degrade performance, and require +// changing the data type of several variables that process addresses. The +// client is responsible for ensuring correctness and performance with +// matrices spanning several billion elements in one direction. +// - The matrix dimensions must be known at compile time, via function +// constants. Dynamic matrix shapes are beyond the scope of this reference +// implementation. Dynamic shapes cause a non-negligible regression to +// shader execution speed. However, they could minimize a compilation +// latency bottleneck in some use cases. +// - Limitations to batch size: +// - Dictated by how the client modifies the code to implement batching. +// - Dynamic batch shapes would likely not harm performance much. For example, +// someone could enter an array of pointers/memory offsets to different +// matrices in the batch. Each slice of a 3D thread grid could read a +// different pointer from memory, and use that pointer as the A/B/C matrix. +// Another approach is to restrict the input format, so all matrices are +// stored contiguously in memory. Then, the memory offset could be computed +// analytically from matrix size and the Z dimension in a 3D thread grid. +// +// Another note: +// - The rows of the matrix must be contiguous in memory. Supporting strides +// that differ from the actual matrix dimensions should not be difficult, but +// it is out of scope for this reference kernel. +constant uint M [[function_constant(0)]]; +constant uint N [[function_constant(1)]]; +constant uint K [[function_constant(2)]]; + +// Whether each matrix is transposed. +constant bool A_trans [[function_constant(10)]]; +constant bool B_trans [[function_constant(11)]]; + +// Define the memory layout of the matrix block. +constant ushort M_group [[function_constant(200)]]; +constant ushort N_group [[function_constant(201)]]; +constant ushort K_group [[function_constant(202)]]; + +constant bool prefer_async_copy [[function_constant(206)]]; +constant bool ideal_grouping [[function_constant(207)]]; + +constant ushort A_leading_dim = A_trans ? M : K; +constant ushort B_leading_dim = B_trans ? K : N; +constant ushort A_leading_block_dim = A_trans ? M_group : K_group; +constant ushort B_leading_block_dim = B_trans ? K_group : N_group; + +// Thresholds that mark the matrix edge. 
+constant uint M_edge = M - (M % M_group); +constant uint N_edge = N - (N % N_group); + +// The layout of threads within a SIMD matrix. +// +// 0 0 1 1 8 8 9 9 +// 2 2 3 3 10 10 11 11 +// 4 4 5 5 12 12 13 13 +// 6 6 7 7 14 14 15 15 +// 16 16 17 17 24 24 25 25 +// 18 18 19 19 26 26 27 27 +// 20 20 21 21 28 28 29 29 +// 22 22 23 23 30 30 31 31 +// +// This is Morton order, a method for coalescing data accesses. It is used +// in a variety of contexts, from ray tracing acceleration structures, to +// nodal-point Laplacians, to sorting large lattices of atoms. +// +// Source: https://patents.google.com/patent/US11256518B2 +METAL_FUNC ushort2 morton_order(ushort thread_index_in_simdgroup) { + ushort lane_id = thread_index_in_simdgroup; + ushort quad_id = lane_id / 4; + + constexpr ushort QUADRANT_SPAN_M = 4; + constexpr ushort THREADS_PER_QUADRANT = 8; + ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M; + ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2); + ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant; + + ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4 + ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2 + ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant; + + return ushort2(N_in_simd, M_in_simd); +} + +// Indexes into an array of registers. +// +// Calls to this function are expected to be evaluated at compile time. The +// array indices transform into register offsets, which are embedded into the +// assembly code. +template +METAL_FUNC thread simdgroup_matrix_storage* get_sram( + thread simdgroup_matrix_storage *sram, + ushort sram_leading_dim, + ushort2 matrix_origin +) { + return sram + (matrix_origin.y / 8) * (sram_leading_dim / 8) + (matrix_origin.x / 8); +} + +// One multiply-accumulate loop iteration, or 8 dot products. +template< + typename T, + typename U = T, + ushort M_register, + ushort N_register +> +METAL_FUNC void multiply_accumulate( + const device T *A_src, + const device U *B_src, + thread simdgroup_matrix_storage *A_sram, + thread simdgroup_matrix_storage *B_sram, + thread simdgroup_matrix_storage *C_sram, + ushort k +) { +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_register; m += 8) { + ushort2 origin(0, m); + auto A = get_sram(A_sram, 8, origin); + A->load(A_src, A_leading_dim, ushort2(k, m), A_trans); + } +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_register; n += 8) { + ushort2 origin(n, 0); + auto B = get_sram(B_sram, N_register, origin); + B->load(B_src, B_leading_dim, ushort2(n, k), B_trans); + } +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_register; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_register; n += 8) { + auto A = get_sram(A_sram, 8, ushort2(0, m)); + auto B = get_sram(B_sram, N_register, ushort2(n, 0)); + auto C = get_sram(C_sram, N_register, ushort2(n, m)); + C->multiply(*A, *B); + } + } +} + +// One multiply-accumulate loop iteration, or 8 dot products. 
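As an aside, the register-tile indexing that `morton_order` and `get_sram` implement is easy to sanity-check on the host. The Rust sketch below mirrors the same arithmetic; it is an illustration only, not part of this diff, and could for instance back a CPU-side unit test.

// Host-side mirror of the kernel's index math (illustrative only).

/// Maps a lane id (0..32) to the (N, M) origin of the 1x2 element strip it
/// owns inside an 8x8 SIMD matrix, mirroring `morton_order` above.
fn morton_order(lane_id: u16) -> (u16, u16) {
    let quad_id = lane_id / 4;
    let m_floor_of_quadrant = (quad_id / 4) * 4;
    let m_in_quadrant = (lane_id / 2) % 4;
    let n_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
    let n_in_quadrant = (lane_id % 2) * 2; // 0 or 2
    (
        n_floor_of_quadrant + n_in_quadrant,
        m_floor_of_quadrant + m_in_quadrant,
    )
}

/// Mirrors `get_sram`: the per-thread register tiles form a row-major grid of
/// 8x8 blocks, so an (x, y) element origin maps to block index
/// (y / 8) * (leading_dim / 8) + (x / 8).
fn sram_index(sram_leading_dim: u16, matrix_origin: (u16, u16)) -> u16 {
    (matrix_origin.1 / 8) * (sram_leading_dim / 8) + matrix_origin.0 / 8
}

fn main() {
    // Lane 8 owns columns 4..6 of row 0, matching the diagram above.
    assert_eq!(morton_order(8), (4, 0));
    // In a 32-wide accumulator, element origin (8, 16) lives in block 9.
    assert_eq!(sram_index(32, (8, 16)), 9);
}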
+template< + typename T, + typename U = T, + ushort M_register, + ushort N_register +> +METAL_FUNC void multiply_accumulate( + const threadgroup T *A_src, + const threadgroup U *B_src, + thread simdgroup_matrix_storage *A_sram, + thread simdgroup_matrix_storage *B_sram, + thread simdgroup_matrix_storage *C_sram, + ushort k +) { +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_register; m += 8) { + ushort2 origin(0, m); + auto A = get_sram(A_sram, 8, origin); + A->load(A_src, A_leading_dim, ushort2(k, m), A_trans); + } +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_register; n += 8) { + ushort2 origin(n, 0); + auto B = get_sram(B_sram, N_register, origin); + B->load(B_src, B_leading_dim, ushort2(n, k), B_trans); + } +#pragma clang loop unroll(full) + for (ushort m = 0; m < M_register; m += 8) { +#pragma clang loop unroll(full) + for (ushort n = 0; n < N_register; n += 8) { + auto A = get_sram(A_sram, 8, ushort2(0, m)); + auto B = get_sram(B_sram, N_register, ushort2(n, 0)); + auto C = get_sram(C_sram, N_register, ushort2(n, m)); + C->multiply(*A, *B); + } + } +} + +// Metal function arguments. +// +// A: the left-hand side matrix +// - dimensions: M x K +// K x M (transposed) +// - memory precision: memA +// - register precision: regA +// +// B: the right-hand side matrix +// - dimensions: K x N +// N x K (transposed) +// - memory precision: memB +// - register precision: regB +// +// C: the output matrix, alternatively the dot product accumulator +// - dimensions: M x N +// - memory precision: memC +// - register precision: regC +// +// threadgroup_block: the chunk of threadgroup memory allocated at runtime +// - ideally 10 KB or less +// - precision: void/8-bit integer to make the pointer arithmetic more legible +template < + typename T, + typename U = T, + ushort M_block_dim, + ushort N_block_dim, + ushort K_block_dim, + ushort M_split, + ushort N_split +> +void gemm_impl( + device T *A [[buffer(0)]], + device U *B [[buffer(1)]], + device U *C [[buffer(2)]], + + threadgroup uchar *threadgroup_block [[threadgroup(0)]], + + uint3 gid [[threadgroup_position_in_grid]], + ushort sidx [[simdgroup_index_in_threadgroup]], + ushort lane_id [[thread_index_in_simdgroup]] +) { + constexpr ushort M_register = M_block_dim / M_split; + constexpr ushort N_register = N_block_dim / N_split; + constexpr ushort threadgroup_size = 32 * M_split * N_split; + + const ushort iteration_start = prefer_async_copy ? 0 : (K - (K % K_group)); + + // Find the number of elements in the final block. If the matrix + // dimensions are perfectly divisibly by block dimensions, we don't want + // this value to be zero. The final block is a full block. + const uint M_remainder = (M % M_register == 0) + ? M_register : M % M_register; + const ushort N_remainder = (N % N_register == 0) + ? N_register : N % N_register; + const ushort K_remainder = (K % K_group == 0) + ? K_group : K % K_group; + const ushort K_remainder_padded = (K_remainder + 7) / 8 * 8; + + // Shift the final block, so it doesn't access out-of-bounds memory. + const ushort M_shift = (M < M_group) ? 0 : M_register - M_remainder; + const ushort N_shift = (N < N_group) ? 0 : N_register - N_remainder; + + auto A_block = (threadgroup T*)(threadgroup_block); + auto B_block = (threadgroup U*)(threadgroup_block + (M*K)); + ushort2 sid(sidx % N_split, sidx / N_split); + ushort2 morton_offset = morton_order(lane_id); + + // Return early if the SIMD is out of bounds. 
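The remainder and shift bookkeeping in `gemm_impl` is easiest to follow with concrete numbers. The standalone Rust sketch below replays the formulas above for a hypothetical M = 100 with 32-wide blocks (M_group = M_register = 32); it is illustrative only and not part of this diff.

// Worked example of the edge handling in gemm_impl (illustrative only).
fn main() {
    let (m, m_group, m_register) = (100u32, 32u32, 32u32);
    let m_remainder = if m % m_register == 0 { m_register } else { m % m_register };
    let m_shift = if m < m_group { 0 } else { m_register - m_remainder };
    let m_edge = m - m % m_group;
    // The last row of threadgroups nominally starts at row m_edge = 96.
    // Shifting it back by m_shift = 28 makes it cover rows 68..100, so its
    // 32-row block never reads or writes past row M.
    assert_eq!((m_remainder, m_shift, m_edge, m_edge - m_shift), (4, 28, 96, 68));
}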
+ // + // There could be some threadgroups where the matrix edge cuts straight + // through the middle of the block. SIMDs on the right or bottom of the + // dividing line must be stopped from causing out-of-bounds accesses. This is + // the reason for the early exit. + uint M_offset = gid.y * M_group; + uint N_offset = gid.x * N_group; + if (M_offset + sid.y * M_register >= M || + N_offset + sid.x * N_register >= N) { + return; + } + ushort2 offset_in_group(sid.x * M_register + morton_offset.x, + sid.y * N_register + morton_offset.y); + + // Shift the matrix block within bounds, if possible. + if ((M_shift != 0) && (gid.y * M_group >= M_edge)) { + M_offset -= M_shift; + } + if ((N_shift != 0) && (gid.x * N_group >= N_edge)) { + N_offset -= N_shift; + } + + simdgroup_matrix_storage C_sram[(M_register / 8) * (N_register / 8)]; + + // Initialize the accumulator. + #pragma clang loop unroll(full) + for (ushort m = 0; m < M_register; m += 8) { + #pragma clang loop unroll(full) + for (ushort n = 0; n < N_register; n += 8) { + ushort2 origin(n, m); + auto C = get_sram(C_sram, N_register, origin); + *C = simdgroup_matrix_storage(0); + } + } + + // Perform the iterations where async copy is avoided. + for (uint k = 0; k < iteration_start; k += 8) { + uint2 A_offset(k, M_offset); + uint2 B_offset(N_offset, k); + A_offset += uint2(morton_offset.x, offset_in_group.y); + B_offset += uint2(offset_in_group.x, morton_offset.y); + + auto A_src = simdgroup_matrix_storage::apply_offset( + A, A_leading_dim, A_offset, A_trans); + auto B_src = simdgroup_matrix_storage::apply_offset( + B, N, B_offset, B_trans); + + simdgroup_matrix_storage A_sram[M_register / 8]; + simdgroup_matrix_storage B_sram[N_register / 8]; + multiply_accumulate( + A_src, B_src, A_sram, B_sram, C_sram, 0); + } + + // Perform the iterations where async copy is used. + for (uint k = iteration_start; k < K; k += K_group) { + // Launch an async copy from device to threadgroup memory. 
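For clarity, here is how the K dimension is split between the direct-load phase and the async-copy phase. The Rust sketch below replays the `iteration_start` logic for a hypothetical K = 70 with K_group = 32; it is illustrative only and not part of this diff.

// Which k values each of the two loops above handles (illustrative only).
fn main() {
    let (k_total, k_group) = (70u32, 32u32);
    for prefer_async_copy in [false, true] {
        let iteration_start = if prefer_async_copy { 0 } else { k_total - k_total % k_group };
        let direct: Vec<u32> = (0..iteration_start).step_by(8).collect();
        let async_copied: Vec<u32> = (iteration_start..k_total).step_by(k_group as usize).collect();
        // prefer_async_copy = false: direct [0, 8, ..., 56], async-copy [64] (ragged tail only).
        // prefer_async_copy = true:  direct [],              async-copy [0, 32, 64].
        println!("prefer_async_copy={prefer_async_copy}: direct={direct:?}, async-copy={async_copied:?}");
    }
}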
+ if (sidx == 0) { + uint2 A_offset(k, M_offset); + uint2 B_offset(N_offset, k); + auto A_src = simdgroup_matrix_storage::apply_offset( + A, A_leading_dim, A_offset, A_trans); + auto B_src = simdgroup_matrix_storage::apply_offset( + B, N, B_offset, B_trans); + + ushort M_tile_dimension = min(uint(M_group), M - M_offset); + ushort N_tile_dimension = min(uint(N_group), N - N_offset); + ushort K_tile_dimension = min(uint(K_group), K - k); + ushort K_tile_padded = min(uint(K_group), (K + K_remainder_padded - K_remainder) - k); + + ushort2 A_tile_src(K_tile_dimension, M_tile_dimension); + ushort2 B_tile_src(N_tile_dimension, K_tile_dimension); + ushort2 A_tile_dst(K_tile_padded, M_tile_dimension); + ushort2 B_tile_dst(N_tile_dimension, K_tile_padded); + + simdgroup_event events[2]; + events[0].async_copy(A_block, A_leading_block_dim, A_tile_dst, + A_src, A_leading_dim, A_tile_src, A_trans); + events[1].async_copy(B_block, B_leading_block_dim, B_tile_dst, + B_src, B_leading_dim, B_tile_src, B_trans); + simdgroup_event::wait(2, events); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + ushort2 A_block_offset(morton_offset.x, offset_in_group.y); + ushort2 B_block_offset(offset_in_group.x, morton_offset.y); + auto A_block_src = simdgroup_matrix_storage::apply_offset( + A_block, A_leading_block_dim, A_block_offset, A_trans); + auto B_block_src = simdgroup_matrix_storage::apply_offset( + B_block, B_leading_block_dim, B_block_offset, B_trans); + + simdgroup_matrix_storage A_sram[(M_register / 8) * (K_block_dim / 8)]; + simdgroup_matrix_storage B_sram[(K_block_dim / 8) * (N_register / 8)]; + #pragma clang loop unroll(full) + for (ushort k = 0; k < K_remainder_padded; k += 8) { + multiply_accumulate( + A_block_src, B_block_src, A_sram, B_sram, C_sram, k); + } + + // Will there be any iterations after this one? + if (k + K_group < K) { + // If so, we haven't reached the edge of either input matrix yet. + #pragma clang loop unroll(full) + for (ushort k = K_remainder_padded; k < K_group; k += 8) { + multiply_accumulate( + A_block_src, B_block_src, A_sram, B_sram, C_sram, k); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + if (!prefer_async_copy && (M >= M_group) && (N >= N_group)) { + // Fast path for matrices that qualify. + uint2 C_offset(N_offset + offset_in_group.x, + M_offset + offset_in_group.y); + auto C_dst = simdgroup_matrix_storage::apply_offset( + C, N, C_offset); + + // Write the accumulator to device memory. + #pragma clang loop unroll(full) + for (ushort m = 0; m < M_register; m += 8) { + #pragma clang loop unroll(full) + for (ushort n = 0; n < N_register; n += 8) { + ushort2 origin(n, m); + auto C = get_sram(C_sram, N_register, origin); + C->store(C_dst, N, origin); + } + } + } else { + // Slow path for when memory must be handled more carefully. + auto C_block = (threadgroup U*)(threadgroup_block); + auto C_block_dst = simdgroup_matrix_storage::apply_offset( + C_block, N_group, offset_in_group); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write the accumulator to threadgroup memory. + #pragma clang loop unroll(full) + for (ushort m = 0; m < M_register; m += 8) { + #pragma clang loop unroll(full) + for (ushort n = 0; n < N_register; n += 8) { + ushort2 origin(n, m); + auto C = get_sram(C_sram, N_register, origin); + C->store(C_block_dst, N_group, origin); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Launch the async copy from threadgroup to device memory. 
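The edge-tile sizing above deserves a worked example: the source tile shrinks to the valid remainder, while the destination tile is padded up to a multiple of 8 and, as far as I can tell, relies on the copy's clamp-to-zero default to fill the extra columns. A small Rust sketch with hypothetical sizes, not part of this diff:

// Tile sizes at the ragged K edge (illustrative only). K = 70, K_group = 32,
// and we are at the last async-copy iteration, k = 64.
fn main() {
    let (k_total, k_group, k) = (70u32, 32u32, 64u32);
    let k_remainder = if k_total % k_group == 0 { k_group } else { k_total % k_group };
    let k_remainder_padded = (k_remainder + 7) / 8 * 8;
    let k_tile_dimension = k_group.min(k_total - k); // valid source columns
    let k_tile_padded = k_group.min(k_total + k_remainder_padded - k_remainder - k);
    // The device-side tile is 6 columns wide, but the threadgroup-side tile is
    // padded to 8 so the 8x8 simdgroup matrices read zero-filled lanes rather
    // than garbage.
    assert_eq!((k_tile_dimension, k_tile_padded), (6, 8));
}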
+ if (sidx == 0) { + uint2 C_offset(gid.x * N_group, gid.y * M_group); + ushort2 C_tile(min(uint(N_group), N - C_offset.x), + min(uint(M_group), M - C_offset.y)); + auto C_dst = simdgroup_matrix_storage::apply_offset( + C, N, C_offset); + + // If we shift successfully, the garbage zone moves from the bottom right + // to the top left. + if ((M_shift != 0) || (N_shift != 0)) { + ushort2 C_block_shift(0, 0); + if ((M_shift != 0) && (C_offset.y >= M_edge)) { + C_block_shift.y = M_shift; + } + if ((N_shift != 0) && (C_offset.x >= N_edge)) { + C_block_shift.x = N_shift; + } + C_block = simdgroup_matrix_storage::apply_offset( + C_block, N_group, C_block_shift); + } + + simdgroup_event event; + event.async_copy(C_dst, N, C_tile, C_block, N_group, C_tile); + } + } +} + +kernel void hgemm( + device half *A [[buffer(0)]], + device half *B [[buffer(1)]], + device half *C [[buffer(2)]], + + threadgroup uchar *threadgroup_block [[threadgroup(0)]], + + uint3 gid [[threadgroup_position_in_grid]], + ushort sidx [[simdgroup_index_in_threadgroup]], + ushort lane_id [[thread_index_in_simdgroup]] +) { + if (ideal_grouping) { + gemm_impl< + half, + half, + 32, + 32, + 32, + 1, + 1 + >( + A, B, C, threadgroup_block, gid, sidx, lane_id + ); + } else { + gemm_impl< + half, + half, + 48, + 48, + 32, + 1, + 1 + >( + A, B, C, threadgroup_block, gid, sidx, lane_id + ); + } +} + +kernel void sgemm( + device float *A [[buffer(0)]], + device float *B [[buffer(1)]], + device float *C [[buffer(2)]], + + threadgroup uchar *threadgroup_block [[threadgroup(0)]], + + uint3 gid [[threadgroup_position_in_grid]], + ushort sidx [[simdgroup_index_in_threadgroup]], + ushort lane_id [[thread_index_in_simdgroup]] +) { + if (prefer_async_copy) { + // TODO: figure out correct splits + if (ideal_grouping) { + gemm_impl< + float, + float, + 32, + 32, + 32, + 2, + 2 + >( + A, B, C, threadgroup_block, gid, sidx, lane_id + ); + } else { + gemm_impl< + float, + float, + 48, + 48, + 24, + 2, + 2 + >( + A, B, C, threadgroup_block, gid, sidx, lane_id + ); + } + } else { + // TODO: figure out correct splits + constexpr ushort M_split = 1; + constexpr ushort N_split = 1; + if (ideal_grouping) { + gemm_impl< + float, + float, + 32, + 32, + 32, + M_split, + N_split + >( + A, B, C, threadgroup_block, gid, sidx, lane_id + ); + } else { + gemm_impl< + float, + float, + 48, + 48, + 24, + M_split, + N_split + >( + A, B, C, threadgroup_block, gid, sidx, lane_id + ); + } + } +} diff --git a/candle-metal-kernels/src/kernels/matrix_storage.metal b/candle-metal-kernels/src/kernels/matrix_storage.metal new file mode 100644 index 00000000..0dfc75cf --- /dev/null +++ b/candle-metal-kernels/src/kernels/matrix_storage.metal @@ -0,0 +1,243 @@ +// -*- Metal -*- +//===-- metal_simdgroup_matrix_storage ------------------------------------===// +// Copyright (c) 2024 Philip Turner. 
See MIT LICENSE +//===----------------------------------------------------------------------===// + +#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE +#define __METAL_SIMDGROUP_MATRIX_STORAGE + +#pragma METAL internals : enable +namespace metal +{ + template + struct simdgroup_matrix_storage { + typedef vec storage_type; + + storage_type t; + + METAL_FUNC thread vec* thread_elements() thread { + return reinterpret_cast*>(&t); + } + + METAL_FUNC simdgroup_matrix_storage() thread = default; + + METAL_FUNC simdgroup_matrix_storage(vec thread_elements) thread { + *(this->thread_elements()) = thread_elements; + } + + METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y; + } else { + return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x; + } + } + + METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + return src + matrix_origin.x * elements_per_row + matrix_origin.y; + } else { + return src + matrix_origin.y * elements_per_row + matrix_origin.x; + } + } + + template + METAL_FUNC void load(const device U *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y); + uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y); + U memoryForm0 = src[address0]; + U memoryForm1 = src[address1]; + ((thread T*)thread_elements())[0] = T(memoryForm0); + ((thread T*)thread_elements())[1] = T(memoryForm1); + } else if (elements_per_row % 2 != 0) { + uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0); + uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1); + U memoryForm0 = src[address0]; + U memoryForm1 = src[address1]; + ((thread T*)thread_elements())[0] = T(memoryForm0); + ((thread T*)thread_elements())[1] = T(memoryForm1); + } else { + auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0); + vec memoryForm = *(const device vec*)(src + combinedAddress); + *(thread_elements()) = vec(memoryForm); + } + } + + // WARNING: 'T' must be 'float'. 
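To make the three addressing paths in `load` concrete: transposed loads walk down a column with two scalar reads, odd row strides fall back to two scalar reads (presumably to keep the packed 2-wide read aligned), and even strides use a single 2-element vector read. A host-side Rust sketch of the same index math, illustrative only and not part of this diff:

// Element indices touched by a single thread's load (illustrative only).
fn load_addresses(elements_per_row: u32, origin: (u32, u32), transpose: bool) -> Vec<u32> {
    let (x, y) = origin;
    if transpose {
        // Two scalar reads walking down a column.
        vec![x * elements_per_row + y, (x + 1) * elements_per_row + y]
    } else if elements_per_row % 2 != 0 {
        // Odd row stride: two scalar reads along the row.
        vec![y * elements_per_row + x, y * elements_per_row + x + 1]
    } else {
        // Even row stride: one 2-element vector read starting here.
        vec![y * elements_per_row + x]
    }
}

fn main() {
    assert_eq!(load_addresses(64, (8, 2), false), vec![136]); // one vec2 read
    assert_eq!(load_addresses(65, (8, 2), false), vec![138, 139]); // two scalar reads
    assert_eq!(load_addresses(64, (8, 2), true), vec![514, 578]); // transposed
}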
+ METAL_FUNC void load_bfloat(const device bfloat *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y); + uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y); + bfloat memoryForm0 = src[address0]; + bfloat memoryForm1 = src[address1]; + + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + registerForm[1] = memoryForm0; + registerForm[3] = memoryForm1; + ((thread bfloat4*)thread_elements())[0] = registerForm; + } else { + auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0); + bfloat2 memoryForm = *(const device packed_bfloat2*)(src + combinedAddress); + + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + ((thread float*)®isterForm)[1] = *(thread float*)(&memoryForm); + ((thread bfloat*)®isterForm)[1] = memoryForm[0]; + ((thread bfloat4*)thread_elements())[0] = registerForm; + } + } + + template + METAL_FUNC void load(const threadgroup U *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y); + ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y); + U memoryForm0 = src[address0]; + U memoryForm1 = src[address1]; + ((thread T*)thread_elements())[0] = T(memoryForm0); + ((thread T*)thread_elements())[1] = T(memoryForm1); + } else if (elements_per_row % 2 != 0) { + ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0); + ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1); + U memoryForm0 = src[address0]; + U memoryForm1 = src[address1]; + ((thread T*)thread_elements())[0] = T(memoryForm0); + ((thread T*)thread_elements())[1] = T(memoryForm1); + } else { + auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0); + vec memoryForm = *(const threadgroup vec*)(src + combinedAddress); + *(thread_elements()) = vec(memoryForm); + } + } + + // WARNING: 'T' must be 'float'. 
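The bfloat paths keep each bfloat in the upper 16 bits of a float register lane, which lines up with the standard bf16-to-f32 widening trick: a bf16 value is simply the top half of the corresponding f32. A host-side Rust illustration of that bit layout, not part of the kernel:

// bf16 <-> f32 bit layout (illustrative only).
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

fn f32_to_bf16_bits(x: f32) -> u16 {
    // Truncating conversion: keep the sign, exponent, and top 7 mantissa bits.
    // Real conversions usually round to nearest even instead of truncating.
    (x.to_bits() >> 16) as u16
}

fn main() {
    let x = 1.5f32; // exactly representable in bf16
    assert_eq!(bf16_bits_to_f32(f32_to_bf16_bits(x)), x);
}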
+ METAL_FUNC void load_bfloat(const threadgroup bfloat *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y); + ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y); + bfloat memoryForm0 = src[address0]; + bfloat memoryForm1 = src[address1]; + + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + registerForm[1] = memoryForm0; + registerForm[3] = memoryForm1; + ((thread bfloat4*)thread_elements())[0] = registerForm; + } else { + auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0); + bfloat2 memoryForm = *(const threadgroup packed_bfloat2*)(src + combinedAddress); + + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + ((thread float*)®isterForm)[1] = *(thread float*)(&memoryForm); + ((thread bfloat*)®isterForm)[1] = memoryForm[0]; + ((thread bfloat4*)thread_elements())[0] = registerForm; + } + } + + template + METAL_FUNC void store(device U *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y); + uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y); + T registerForm0 = ((thread T*)thread_elements())[0]; + T registerForm1 = ((thread T*)thread_elements())[1]; + dst[address0] = U(registerForm0); + dst[address1] = U(registerForm1); + } else if (elements_per_row % 2 != 0) { + uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0); + uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1); + T registerForm0 = ((thread T*)thread_elements())[0]; + T registerForm1 = ((thread T*)thread_elements())[1]; + dst[address0] = U(registerForm0); + dst[address1] = U(registerForm1); + } else { + auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0); + vec registerForm = *(thread_elements()); + *(device vec*)(dst + combinedAddress) = vec(registerForm); + } + } + + // WARNING: 'T' must be 'float'. 
+ METAL_FUNC void store_bfloat(device bfloat *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y); + uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y); + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + registerForm[2] = registerForm[1]; + dst[address0] = registerForm[2]; + dst[address1] = registerForm[3]; + } else if (elements_per_row % 2 != 0) { + uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0); + uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1); + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + registerForm[2] = registerForm[1]; + dst[address0] = registerForm[2]; + dst[address1] = registerForm[3]; + } else { + auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0); + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + registerForm[2] = registerForm[1]; + float memoryForm = ((thread float*)®isterForm)[1]; + *(device float*)(dst + combinedAddress) = memoryForm; + } + } + + template + METAL_FUNC void store(threadgroup U *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y); + ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y); + T registerForm0 = ((thread T*)thread_elements())[0]; + T registerForm1 = ((thread T*)thread_elements())[1]; + dst[address0] = U(registerForm0); + dst[address1] = U(registerForm1); + } else if (elements_per_row % 2 != 0) { + ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0); + ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1); + T registerForm0 = ((thread T*)thread_elements())[0]; + T registerForm1 = ((thread T*)thread_elements())[1]; + dst[address0] = U(registerForm0); + dst[address1] = U(registerForm1); + } else { + auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0); + vec registerForm = *(thread_elements()); + *(threadgroup vec*)(dst + combinedAddress) = vec(registerForm); + } + } + + // WARNING: 'T' must be 'float'. 
+ METAL_FUNC void store_bfloat(threadgroup bfloat *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) { + if (transpose_matrix) { + ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y); + ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y); + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + registerForm[2] = registerForm[1]; + dst[address0] = registerForm[2]; + dst[address1] = registerForm[3]; + } else if (elements_per_row % 2 != 0) { + ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0); + ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1); + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + registerForm[2] = registerForm[1]; + dst[address0] = registerForm[2]; + dst[address1] = registerForm[3]; + } else { + auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0); + bfloat4 registerForm = *(thread bfloat4*)(thread_elements()); + registerForm[2] = registerForm[1]; + float memoryForm = ((thread float*)®isterForm)[1]; + *(threadgroup float*)(dst + combinedAddress) = memoryForm; + } + } + + template + METAL_FUNC void multiply(simdgroup_matrix_storage a, simdgroup_matrix_storage b, bool accumulate = true) { + if (!accumulate) { + *(thread_elements()) = vec(0); + } + t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage::storage_type()); + } + }; +} // namespace metal +#pragma METAL internals : disable + +#endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 65d7f62b..8c98de13 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function, - FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, + FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, MTLGPUFamily, }; use std::collections::HashMap; use std::ffi::c_void; @@ -22,7 +22,7 @@ const RANDOM: &str = include_str!("kernels/random.metal"); const QUANTIZED: &str = include_str!("kernels/quantized.metal"); const SORT: &str = include_str!("kernels/sort.metal"); const MFA: &[u8] = include_bytes!("libraries/libMetalFlashAttention.metallib"); -const CANDLE: &[u8] = include_bytes!("libraries/libMetalFlashAttention.metallib"); +const CANDLE: &[u8] = include_bytes!("libraries/candle.metallib"); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { @@ -1473,6 +1473,21 @@ pub fn call_gemm( rhs_buffer: &Buffer, output: &Buffer, ) -> Result<(), MetalKernelError> { + let prefer_async_copy = !device.supports_family(MTLGPUFamily::Apple9); + let mut ideal_grouping = false; + /* + let mut actual_groups = 1; + actual_groups *= divide(m, 48); + actual_groups *= divide(n, 48); + actual_groups *= b; + + let core_count = get_device_core_count(device); + let ideal_grouping = if name == "sgemm" { + actual_groups <= core_count * 6 + } else { + actual_groups <= core_count * 9 + }; + */ assert!(rhs_stride.len() >= 2); assert!(lhs_stride.len() >= 2); let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; @@ -1543,14 +1558,16 @@ pub fn call_gemm( (113, Value::Bool(false)), (50_000, Value::Bool(false)), // End garbage - (200, Value::U16(m_simd)), - (201, Value::U16(n_simd)), - (202, Value::U16(k_simd)), + (200, Value::U16(32)), + (201, Value::U16(32)), + (202, Value::U16(32)), 
+        (206, Value::Bool(prefer_async_copy)),
+        (207, Value::Bool(ideal_grouping)),
         (210, Value::U16(m_splits)),
         (211, Value::U16(n_splits)),
         (50_001, Value::Bool(fused_bias)),
     ]));
-    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
+    let pipeline = kernels.load_pipeline_with_constants(device, Source::Candle, name, constants)?;
     let m_group = m_simd * m_splits;
     let n_group = n_simd * n_splits;
diff --git a/candle-metal-kernels/src/libraries/candle.metallib b/candle-metal-kernels/src/libraries/candle.metallib
index 80bc7369..1a9df376 100644
Binary files a/candle-metal-kernels/src/libraries/candle.metallib and b/candle-metal-kernels/src/libraries/candle.metallib differ
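For reference, the commented-out `ideal_grouping` heuristic in `call_gemm` can be written out as a standalone function. The sketch below is illustrative only: it assumes `divide` is a ceiling-division helper, takes the 48x48 block size and the 6x/9x occupancy factors verbatim from the comment, and leaves the core count as an input because the diff does not implement `get_device_core_count`.

// Reconstruction of the commented-out grouping heuristic (illustrative only).
fn divide(n: usize, d: usize) -> usize {
    (n + d - 1) / d // ceiling division
}

fn ideal_grouping(name: &str, m: usize, n: usize, b: usize, core_count: usize) -> bool {
    let actual_groups = divide(m, 48) * divide(n, 48) * b;
    if name == "sgemm" {
        actual_groups <= core_count * 6
    } else {
        actual_groups <= core_count * 9
    }
}

fn main() {
    // A 4096x4096 sgemm on a hypothetical 10-core GPU: 86 * 86 = 7396 groups,
    // well above 10 * 6 = 60, so the heuristic reports false and the kernels
    // above would take their 48x48 specialization.
    assert_eq!(ideal_grouping("sgemm", 4096, 4096, 1, 10), false);
}

As committed, `ideal_grouping` is hard-coded to `false`, so function constant 207 is always false and both `hgemm` and `sgemm` currently take their 48-wide specialization.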