Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 02:16:37 +00:00
Tmp gemm.

candle-metal-kernels/src/GEMM.metal (new file, 499 lines)
@@ -0,0 +1,499 @@
//
//  GEMM.metal
//  MetalFlashAttention
//
//  Created by Philip Turner on 6/23/23.
//

#include <metal_stdlib>
#include "metal_data_type"
#include "metal_simdgroup_event"
#include "metal_simdgroup_matrix_storage"
using namespace metal;

// MARK: - Function Constants

// Dimensions of each matrix.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];

// Whether each matrix is transposed.
constant bool A_trans [[function_constant(10)]];
constant bool B_trans [[function_constant(11)]];
constant bool D_trans [[function_constant(13)]];
constant uint A_leading_dim = A_trans ? M : K;
constant uint B_leading_dim = B_trans ? K : N;

// Alpha and beta constants from BLAS.
constant float alpha [[function_constant(20)]];
constant float beta [[function_constant(21)]];

constant bool batched [[function_constant(100)]];
constant bool fused_activation [[function_constant(101)]];
constant bool fused_bias [[function_constant(50001)]]; // 102
constant bool use_bias = is_function_constant_defined(fused_bias) ? fused_bias : false;
constant bool use_activation_function = fused_activation && !fused_bias;
constant bool use_activation = use_bias || use_activation_function;
constant bool batched_activation_function = batched && use_activation_function;

constant ushort M_simd [[function_constant(200)]];
constant ushort N_simd [[function_constant(201)]];
constant ushort K_simd [[function_constant(202)]];

// Elide work on the edge when matrix dimension < SRAM block dimension.
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
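// Worked example (illustrative values, not defaults): with M = 100 and
// M_simd = 48, M % M_simd == 4, so M_modulo = 4 and, since M >= M_simd,
// M_padded stays at M_simd = 48. With a tiny M = 5 and M_simd = 48,
// M_modulo = 5 and M_padded rounds up to the next multiple of 8, i.e. 8.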

constant ushort M_splits [[function_constant(210)]];
constant ushort N_splits [[function_constant(211)]];

constant ushort M_group = M_simd * M_splits;
constant ushort N_group = N_simd * N_splits;
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);

// There is no padding for M reads/writes.
// There is no padding for N reads/writes.
constant ushort K_simd_unpadded = (K % K_simd == 0) ? K_simd : (K % K_simd);
constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;

constant ushort A_sram_length = (M_simd / 8) * 1;
constant ushort B_sram_length = 1 * (N_simd / 8);
constant ushort A_block_length = M_group * K_simd;

// Threadgroup block must fit entire C accumulator and partial sums.
constant ushort A_sram_offset = 0;
constant ushort B_sram_offset = A_sram_offset + A_sram_length;
constant ushort C_sram_offset = B_sram_offset + B_sram_length;
constant ushort A_block_offset = 0;
constant ushort B_block_offset = A_block_offset + A_block_length;

// MARK: - Utilities

template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
  // A_sram[M_simd][8]
  return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
}

template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
  // B_sram[8][N_simd]
  return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}

template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
  // C_sram[M_simd][N_simd]
  return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}

template <typename T>
METAL_FUNC void prefetch(threadgroup T *A_block, device T *A,
                         ushort2 A_tile_src, uint2 A_offset,
                         threadgroup T *B_block, device T *B,
                         ushort2 B_tile_src, uint2 B_offset, uint k)
{
  A_tile_src.x = min(uint(K_simd), K - k);
  B_tile_src.y = min(uint(K_simd), K - k);
  auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
  auto B_src = simdgroup_matrix_storage<T>::apply_offset(B, B_leading_dim, B_offset, B_trans);

  // Rounded-up ceiling for the threadgroup block.
  const uint K_edge_floor = K - K_simd_unpadded;
  const uint K_edge_ceil = K_edge_floor + K_simd_padded;
  ushort K_padded;
  if (K_edge_floor == K_simd) {
    K_padded = K_simd;
  } else {
    K_padded = min(uint(K_simd), K_edge_ceil - k);
  }
  ushort2 A_tile_dst(K_padded, A_tile_src.y);
  ushort2 B_tile_dst(B_tile_src.x, K_padded);

  simdgroup_event events[2];
  events[0].async_copy(A_block, A_block_leading_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
  events[1].async_copy(B_block, B_block_leading_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
  simdgroup_event::wait(2, events);
}

// One iteration of the MACC loop, effectively k=8 iterations.
template <typename T>
METAL_FUNC void multiply_accumulate(thread simdgroup_matrix_storage<T> *sram,
                                    const threadgroup T *A_block,
                                    const threadgroup T *B_block,
                                    bool accumulate = true)
{
#pragma clang loop unroll(full)
  for (ushort m = 0; m < M_padded; m += 8) {
    ushort2 origin(0, m);
    A_sram(sram, origin)->load(A_block, A_block_leading_dim, origin, A_trans);
  }
#pragma clang loop unroll(full)
  for (ushort n = 0; n < N_padded; n += 8) {
    ushort2 origin(n, 0);
    B_sram(sram, origin)->load(B_block, B_block_leading_dim, origin, B_trans);
  }
#pragma clang loop unroll(full)
  for (ushort m = 0; m < M_padded; m += 8) {
    auto A = A_sram(sram, ushort2(0, m));
#pragma clang loop unroll(full)
    for (ushort n = 0; n < N_padded; n += 8) {
      auto B = B_sram(sram, ushort2(n, 0));
      auto C = C_sram(sram, ushort2(n, m));
      C->multiply(*A, *B, accumulate);
    }
  }
}

template <typename T>
METAL_FUNC void partial_store(thread simdgroup_matrix_storage<T> *sram,
                              threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
  for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
    for (ushort n = 0; n < N_padded; n += 8) {
      ushort2 origin(n, m);
      if (is_k_summation) {
        C_sram(sram, origin)->store(C_block, N_simd, origin);
      } else {
        C_sram(sram, origin)->store(C_block, N_group, origin);
      }
    }
  }
}

template <typename T>
METAL_FUNC void partial_accumulate(thread simdgroup_matrix_storage<T> *sram,
                                   threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
  for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
    for (ushort n = 0; n < N_padded; n += 8) {
      ushort2 origin(n, m);
      auto B = B_sram(sram, ushort2(n, 0));
      if (is_k_summation) {
        B->load(C_block, N_simd, origin);
      } else {
        B->load(C_block, N_group, origin);
      }
    }
#pragma clang loop unroll(full)
    for (ushort n = 0; n < N_padded; n += 8) {
      ushort2 origin(n, m);
      auto B = B_sram(sram, ushort2(n, 0));
      auto C = C_sram(sram, origin);
      if (is_k_summation) {
        C->thread_elements()[0] += B->thread_elements()[0];
      } else {
        float2 C_old = float2(B->thread_elements()[0]);
        float2 C_new = float2(C->thread_elements()[0]);
        C->thread_elements()[0] = vec<T, 2>(fast::fma(C_old, beta, C_new));
      }
    }
  }
}

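// Copies the C accumulator tile between device memory and the threadgroup
// block with an async copy; `is_store` selects the direction (only the load
// path waits on the event here).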
template <typename T>
METAL_FUNC void async_access_accumulator(threadgroup T *C_block, device T *C,
                                         uint2 C_offset, bool is_store)
{
  ushort2 C_tile(min(uint(N_group), N - C_offset.x),
                 min(uint(M_group), M - C_offset.y));
  auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, C_offset);

  simdgroup_event event;
  if (is_store) {
    event.async_copy(C_src, N, C_tile, C_block, N_group, C_tile);
  } else {
    event.async_copy(C_block, N_group, C_tile, C_src, N, C_tile);
    simdgroup_event::wait(1, &event);
  }
}

template <typename T>
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
                                  device T *C, bool m_is_edge, bool n_is_edge)
{
  const ushort m_start = (m_is_edge) ? M_modulo : 0;
  const ushort n_start = (n_is_edge) ? N_modulo : 0;
  const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
  const ushort n_end = (n_is_edge) ? N_simd : N_modulo;

#pragma clang loop unroll(full)
  for (ushort m = m_start; m < m_end; m += 8) {
#pragma clang loop unroll(full)
    for (ushort n = n_start; n < n_end; n += 8) {
      ushort2 origin(n, m);
      C_sram(sram, origin)->store(C, N, origin);
    }
  }
}

template <typename T>
struct activation_functor {
  using function = void(threadgroup T *C,
                        device void *D,
                        uint grid_index_in_batch,
                        uint2 matrix_origin,
                        ushort2 tile_dimensions,
                        ushort lane_id);

  typedef visible_function_table<function> function_table;
};

// MARK: - Kernels

template <typename T>
void _gemm_impl(device T *A [[buffer(0)]],
                device T *B [[buffer(1)]],
                device T *C [[buffer(2)]],
                device void *D [[buffer(3), function_constant(use_activation)]],

                threadgroup T *threadgroup_block [[threadgroup(0)]],
                constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
                typename activation_functor<T>::function_table table [[buffer(11), function_constant(use_activation_function)]],
                constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],

                uint3 gid [[threadgroup_position_in_grid]],
                ushort sidx [[simdgroup_index_in_threadgroup]],
                ushort lane_id [[thread_index_in_simdgroup]])
{
  if (batched) {
    // TODO: Re-compute every inner loop iteration for FP64 accumulate.
    ulong3 offsets = matrix_offsets[gid.z].xyz;
    A = (device T*)((device uchar*)A + offsets[0]);
    B = (device T*)((device uchar*)B + offsets[1]);
    C = (device T*)((device uchar*)C + offsets[2]);
  }

  simdgroup_matrix_storage<T> sram[1024];
  auto A_block = threadgroup_block + A_block_offset;
  auto B_block = threadgroup_block + B_block_offset;
  ushort2 sid(sidx % N_splits, sidx / N_splits);
  ushort2 offset_in_simd = simdgroup_matrix_storage<T>::offset(lane_id);

  uint2 A_offset(0, gid.y * M_group);
  uint2 B_offset(gid.x * N_group, 0);
  {
    uint C_base_offset_x = B_offset.x + sid.x * N_simd;
    uint C_base_offset_y = A_offset.y + sid.y * M_simd;
    if (C_base_offset_x >= N || C_base_offset_y >= M) {
      return;
    }
  }

  ushort2 offset_in_group(sid.x * N_simd + offset_in_simd.x,
                          sid.y * M_simd + offset_in_simd.y);

  if (use_bias) {
    if (sidx == 0) {
      auto bias = (device T*)D;
      if (batched) {
        ulong offset = matrix_offsets[gid.z].w;
        bias = (device T*)((device uchar*)bias + offset);
      }

      ushort bias_elements;
      if (is_function_constant_defined(D_trans) && D_trans) {
        bias += A_offset.y;
        bias_elements = min(uint(M_group), M - A_offset.y);
      } else {
        bias += B_offset.x;
        bias_elements = min(uint(N_group), N - B_offset.x);
      }

      simdgroup_event event;
      event.async_copy(threadgroup_block, bias, bias_elements);
      simdgroup_event::wait(1, &event);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (is_function_constant_defined(D_trans) && D_trans) {
      auto bias = threadgroup_block + offset_in_group.y;
#pragma clang loop unroll(full)
      for (ushort m = 0; m < M_padded; m += 8) {
        auto D = bias[m];
#pragma clang loop unroll(full)
        for (ushort n = 0; n < N_padded; n += 8) {
          auto C = C_sram(sram, ushort2(n, m));
          *(C->thread_elements()) = D;
        }
      }
    } else {
      auto bias = threadgroup_block + offset_in_group.x;
#pragma clang loop unroll(full)
      for (ushort n = 0; n < N_padded; n += 8) {
        auto D = *(threadgroup vec<T, 2>*)(bias + n);
#pragma clang loop unroll(full)
        for (ushort m = 0; m < M_padded; m += 8) {
          auto C = C_sram(sram, ushort2(n, m));
          *(C->thread_elements()) = D;
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  ushort2 A_tile_src;
  ushort2 B_tile_src;
  if (sidx == 0) {
    A_tile_src.y = min(uint(M_group), M - A_offset.y);
    B_tile_src.x = min(uint(N_group), N - B_offset.x);
    prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, 0);
  }

  if (K > K_simd && !use_bias) {
#pragma clang loop unroll(full)
    for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
      for (ushort n = 0; n < N_padded; n += 8) {
        *C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<T>(0);
      }
    }
  }

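  // Main loop over K in K_simd-sized tiles: each iteration consumes the A/B
  // tiles staged in threadgroup memory and, when another tile remains,
  // prefetches the next pair before continuing.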
  for (uint K_floor = 0; K_floor < K; K_floor += K_simd) {
    ushort2 A_block_offset(offset_in_simd.x, offset_in_group.y);
    ushort2 B_block_offset(offset_in_group.x, offset_in_simd.y);
    auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(A_block, A_block_leading_dim, A_block_offset, A_trans);
    auto B_block_src = simdgroup_matrix_storage<T>::apply_offset(B_block, B_block_leading_dim, B_block_offset, B_trans);
    threadgroup_barrier(mem_flags::mem_threadgroup);

#pragma clang loop unroll(full)
    for (ushort k = 0; k < K_simd_padded; k += 8) {
      bool accumulate = use_bias || !(K <= K_simd && k == 0);
      multiply_accumulate(sram, A_block_src, B_block_src, accumulate);
      A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
      B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
    }

    if (K_floor + K_simd < K) {
#pragma clang loop unroll(full)
      for (ushort k = K_simd_padded; k < K_simd; k += 8) {
        multiply_accumulate(sram, A_block_src, B_block_src);
        A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
        B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
      }
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (sidx == 0) {
        uint K_next = K_floor + K_simd;
        A_offset.x = K_next;
        B_offset.y = K_next;
        prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, K_next);
      }
    }
  }

  if (alpha != 1) {
#pragma clang loop unroll(full)
    for (int m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
      for (int n = 0; n < N_padded; n += 8) {
        C_sram(sram, ushort2(n, m))->thread_elements()[0] *= alpha;
      }
    }
  }

  uint2 C_offset(B_offset.x, A_offset.y);
  ushort2 C_block_offset = offset_in_group.xy;
  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (beta != 0) {
    if (sidx == 0) {
      async_access_accumulator(threadgroup_block, C, C_offset, false);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
    partial_accumulate(sram, C_block, false);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  if (use_activation_function) {
    auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
    partial_store(sram, C_block, false);
    simdgroup_barrier(mem_flags::mem_threadgroup);

    uint grid_index_in_batch = (batched ? gid.z : 0);
    uint2 matrix_origin = C_offset + uint2(C_block_offset);
    matrix_origin &= ~7;
    ushort2 tile_dimensions(min(uint(N_group), N - matrix_origin.x),
                            min(uint(M_group), M - matrix_origin.y));
    uint function_index = 0;
    if (batched_activation_function) {
      function_index = activation_function_offsets[gid.z];
    }
    table[function_index](C_block, D, grid_index_in_batch, matrix_origin, tile_dimensions, lane_id);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sidx == 0) {
      async_access_accumulator(threadgroup_block, C, C_offset, true);
    }
    return;
  } else if ((M % 8 != 0) || (N % 8 != 0)) {
    auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
    partial_store(sram, C_block, false);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sidx == 0) {
      async_access_accumulator(threadgroup_block, C, C_offset, true);
    }
  } else {
    uint2 matrix_origin = C_offset + uint2(C_block_offset);
    auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, matrix_origin);
    store_accumulator(sram, C_src, false, false);

    const uint M_edge_floor = M - M % M_simd;
    const uint N_edge_floor = N - N % N_simd;
    if (matrix_origin.y < M_edge_floor) {
      store_accumulator(sram, C_src, true, false);
    }
    if (matrix_origin.x < N_edge_floor) {
      store_accumulator(sram, C_src, false, true);
      if (matrix_origin.y < M_edge_floor) {
        store_accumulator(sram, C_src, true, true);
      }
    }
  }
}

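// Exposed kernel entry points: hgemm (half precision) and sgemm (single
// precision) instantiate _gemm_impl.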
kernel void hgemm(device half *A [[buffer(0)]],
                  device half *B [[buffer(1)]],
                  device half *C [[buffer(2)]],
                  device void *D [[buffer(3), function_constant(use_activation)]],

                  threadgroup half *threadgroup_block [[threadgroup(0)]],
                  constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
                  typename activation_functor<half>::function_table table [[buffer(11), function_constant(use_activation_function)]],
                  constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],

                  uint3 gid [[threadgroup_position_in_grid]],
                  ushort sidx [[simdgroup_index_in_threadgroup]],
                  ushort lane_id [[thread_index_in_simdgroup]])
{
  _gemm_impl<half>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
}

kernel void sgemm(device float *A [[buffer(0)]],
                  device float *B [[buffer(1)]],
                  device float *C [[buffer(2)]],
                  device void *D [[buffer(3), function_constant(use_activation)]],

                  threadgroup float *threadgroup_block [[threadgroup(0)]],
                  constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
                  typename activation_functor<float>::function_table table [[buffer(11), function_constant(use_activation_function)]],
                  constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],

                  uint3 gid [[threadgroup_position_in_grid]],
                  ushort sidx [[simdgroup_index_in_threadgroup]],
                  ushort lane_id [[thread_index_in_simdgroup]])
{
  _gemm_impl<float>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
}

candle-metal-kernels/src/lib.rs
@@ -14,6 +14,7 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
+const FLASH: &[u8] = include_bytes!("libMetalFlashAttention.metallib");

fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
@@ -106,6 +107,7 @@ pub enum Source {
    Ternary,
    Cast,
    Reduce,
+    Gemm,
}

macro_rules! ops{
@@ -229,6 +231,7 @@ impl Kernels {
            Source::Indexing => INDEXING,
            Source::Cast => CAST,
            Source::Reduce => REDUCE,
+            Source::Gemm => ""
        }
    }

@@ -241,10 +244,17 @@ impl Kernels {
        if let Some(lib) = libraries.get(&source) {
            Ok(lib.clone())
        } else {
-            let source_content = self.get_library_source(source);
-            let lib = device
-                .new_library_with_source(source_content, &CompileOptions::new())
-                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
+            let lib = match source {
+                Source::Gemm => device
+                    .new_library_with_data(FLASH)
+                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?,
+                _source => {
+                    let source_content = self.get_library_source(source);
+                    device
+                        .new_library_with_source(source_content, &CompileOptions::new())
+                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                }
+            };
            libraries.insert(source, lib.clone());
            Ok(lib)
        }
@@ -291,6 +301,160 @@ impl Kernels {
    }
}

enum Gemm {
    Float,
    Half,
}

impl Gemm {
    fn size_of_dtype(&self) -> usize {
        match self {
            Gemm::Float => 4,
            Gemm::Half => 2,
        }
    }
    fn name(&self) -> &'static str {
        match self {
            Gemm::Float => "sgemm",
            Gemm::Half => "hgemm",
        }
    }
}

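// NOTE: work-in-progress dispatch of the MetalFlashAttention GEMM kernels;
// several bindings used below (`p`, `gemm_config`, `lib`, `utils`, and the
// `a`/`b`/`c`/`d` buffers) are not defined yet at this stage.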
pub fn call_gemm(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: Gemm,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Gemm, name.name())?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    let config = gemm_config(&p);
    let m_group = config.m_group;
    let n_group = config.n_group;
    let k_simd = config.k_simd.value;
    let m_splits = config.m_splits.value;
    let n_splits = config.n_splits.value;

    let size_of_dtype = name.size_of_dtype();
    let a_block_bytes = m_group * k_simd * size_of_dtype;
    let b_block_bytes = k_simd * n_group * size_of_dtype;
    let c_block_bytes = m_group * n_group * size_of_dtype;
    let mut thread_group_memory_length = a_block_bytes + b_block_bytes;

    if p.m % 8 > 0 && p.n % 8 > 0 {
        thread_group_memory_length = max(thread_group_memory_length, c_block_bytes);
    }
    if p.fused_bias {
        let d_block_bytes = if p.d_trans {
            m_group * T::SIZE
        } else {
            n_group * T::SIZE
        };
        thread_group_memory_length = max(thread_group_memory_length, d_block_bytes);
    }

    let grid_size = MTLSize::new(
        utils::ceil_divide(p.n, n_group)?,
        utils::ceil_divide(p.m, m_group)?,
        1,
    );

    let group_size = MTLSize::new((32 * m_splits * n_splits) as NSUInteger, 1, 1);

    let mut flags = 0;
    if p.batched {
        flags |= 0x1;
    }
    if p.fused_activation {
        flags |= 0x2;
    }
    if p.fused_bias {
        flags |= 0x4;
    }

    let constant_values = config.create_function_constant_values();
    let function = lib.get_function(T::FN_NAME, Some(constant_values))?;
    encoder.set_threadgroup_memory_length(0, memory_length);

    encoder.use_resources(&[a.buffer(), b.buffer()], MTLResourceUsage::Read);
    encoder.use_resource(c.buffer(), MTLResourceUsage::Write);

    if let Some(d) = d {
        encoder.use_resource(d.buffer(), MTLResourceUsage::Read);
    }

    encoder.set_buffers(
        0,
        &[Some(a.buffer()), Some(b.buffer()), Some(c.buffer())],
        &[0; 3],
    );
    if let Some(d) = d {
        encoder.set_buffer(3, Some(d.buffer()), 0);
    }

    let mut grid_z = 1;
    if pipeline.flags() & 0x1 > 0 {
        panic!("Batched gemm not implemented yet");
        // let batch_dimensions_a = tensors.a.shape.dropLast(2);
        // let batch_dimensions_b = tensors.b.shape.dropLast(2);
        // let batch_dimensions_c = tensors.c.shape.dropLast(2);
        // assert!(batch_dimensions_a.iter().product() > 0);
        // assert!(
        //     batch_dimensions_b.iter().product() == 1 ||
        //     batch_dimensions_b == batch_dimensions_a);
        // assert!(batch_dimensions_a == batch_dimensions_c);
        // grid_z = batch_dimensions_a.iter().product();
        //
        // if let Some(batch_dimensions_d) = tensors.d { .shape.dropLast(1)
        //     assert!(
        //         batch_dimensions_d.reduce(1, *) == 1 ||
        //         batch_dimensions_d == batch_dimensions_a);
        // }
        //
        // // Mixed precision will cause undefined behavior.
        // let element_size = mem::size_of::<T>();
        // let byte_stride = |shape: Vec<u64>| -> u32 {
        //     let rank = shape.len();
        //     let mut output = element_size * shape[rank - 2] * shape[rank - 1];
        //     if shape.dropLast(2).product() == 1 {
        //         output = 0
        //     }
        //     output
        // } as u32;
        // let byte_stride_a = byte_stride(tensors.a.shape);
        // let byte_stride_b = byte_stride(tensors.b.shape);
        // let byte_stride_c = byte_stride(tensors.c.shape);
        //
        // var byteStrideD = 0
        // if let shapeD = tensors.d?.shape {
        //     let rank = shapeD.count
        //     byteStrideD = element_size * shapeD[rank - 1]
        //     if shapeD.dropLast(1).reduce(1, *) == 1 {
        //         byteStrideD = 0
        //     }
        // }
        // withUnsafeTemporaryAllocation(
        //     of: SIMD4<UInt64>.self, capacity: gridZ
        // ) { buffer in
        //     for i in 0..<buffer.count {
        //         buffer[i] = SIMD4(
        //             UInt64(truncatingIfNeeded: i * byte_stride_a),
        //             UInt64(truncatingIfNeeded: i * byte_stride_b),
        //             UInt64(truncatingIfNeeded: i * byte_stride_c),
        //             UInt64(truncatingIfNeeded: i * byteStrideD))
        //     }
        //
        //     let bufferLength = buffer.count * MemoryLayout<SIMD3<UInt64>>.stride
        //     assert(MemoryLayout<SIMD3<UInt64>>.stride == 8 * 4)
        //     encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
        // }
    }
    Ok(())
}

pub fn call_unary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
@@ -481,7 +645,7 @@ pub fn call_reduce_contiguous(
    length: usize,
    out_length: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -490,7 +654,10 @@ pub fn call_reduce_contiguous(
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

-    set_params!(encoder, (length, elements_to_sum, (input, input_offset), output));
+    set_params!(
+        encoder,
+        (length, elements_to_sum, (input, input_offset), output)
+    );

    let thread_group_count = MTLSize {
        width: out_length as u64,
@@ -742,7 +909,7 @@ mod tests {
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let input = new_buffer(&device, v);
-        let mut output = new_buffer(&device, v);
+        let output = new_buffer(&device, v);
        call_unary_contiguous(
            &device,
            command_buffer,
@@ -766,7 +933,7 @@ mod tests {
        let options = MTLResourceOptions::StorageModeManaged;
        let left = new_buffer(&device, x);
        let right = new_buffer(&device, y);
-        let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+        let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
        call_binary_contiguous(
            &device,
            command_buffer,
@@ -794,7 +961,7 @@ mod tests {
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let input = new_buffer(&device, v);
-        let mut output = new_buffer(&device, v);
+        let output = new_buffer(&device, v);
        let kernels = Kernels::new();
        call_unary_strided(
            &device,
@@ -892,7 +1059,7 @@ mod tests {

    #[test]
    fn cos_strided_random() {
-        let v: Vec<_> = (0..10_000).map(|i| rand::random::<f32>()).collect();
+        let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
        let shape = vec![5_000, 2];
        let strides = vec![1, 5_000];
        let offset = 0;
@@ -934,7 +1101,7 @@ mod tests {
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let input = new_buffer(&device, v);
-        let mut output = new_buffer(&device, v);
+        let output = new_buffer(&device, v);

        call_cast_contiguous(
            &device,
@@ -973,7 +1140,7 @@ mod tests {
        let command_buffer = command_queue.new_command_buffer();

        let input = new_buffer(&device, v);
-        let mut output = new_buffer(&device, v);
+        let output = new_buffer(&device, v);

        let size = v.len();

@@ -995,7 +1162,7 @@ mod tests {
        output.read_to_vec::<T>(v.len())
    }

-    fn run_affine_strided<T: Clone>(
+    fn _run_affine_strided<T: Clone>(
        v: &[T],
        shape: &[usize],
        strides: &[usize],
@@ -1008,9 +1175,7 @@ mod tests {
        let command_buffer = command_queue.new_command_buffer();

        let input = new_buffer(&device, v);
-        let mut output = new_buffer(&device, v);
-
-        let size = v.len();
+        let output = new_buffer(&device, v);

        call_affine_strided(
            &device,
@@ -1106,7 +1271,7 @@ mod tests {
        let left_size: usize = shape[..dim].iter().product();
        let right_size: usize = shape[dim + 1..].iter().product();
        let dst_el = ids.len() * left_size * right_size;
-        let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+        let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);

        let kernels = Kernels::new();
        call_index_select(
@@ -1216,7 +1381,7 @@ mod tests {
        let input = new_buffer(&device, v);

        let options = MTLResourceOptions::StorageModeManaged;
-        let mut output =
+        let output =
            device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
        call_reduce_contiguous(
            &device,
@@ -1226,7 +1391,7 @@ mod tests {
            v.len(),
            out_length,
            &input,
            0,
            &output,
        )
        .unwrap();
@@ -1246,7 +1411,7 @@ mod tests {
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let input = new_buffer(&device, v);
-        let mut output = new_buffer(&device, v);
+        let output = new_buffer(&device, v);
        call_last_softmax(
            &device,
            command_buffer,
@@ -1342,7 +1507,7 @@ mod tests {
            options,
        );

-        let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+        let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
        call_where_cond_strided(
            &device,
            command_buffer,
@@ -1385,4 +1550,50 @@ mod tests {
        );
        assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
    }

    #[test]
    fn flash_gemm() {
        let b = 2;
        let m = 3;
        let n = 2;
        let k = 4;

        let left: Vec<_> = (0..b * m * k).map(|f| f as f32).collect();
        let right: Vec<_> = (0..b * k * n).map(|f| f as f32).collect();
        let out: Vec<_> = (0..b * m * n).map(|f| f as f32).collect();

        let dims = 3;
        let left_shape = vec![b, m, k];
        let right_shape = vec![b, k, n];
        let out_shape = vec![b, m, n];

        let left_stride = vec![m * k, k, 1];

        let device = device();
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let options = MTLResourceOptions::StorageModeManaged;

        let left = device.new_buffer_with_data(
            left.as_ptr() as *const core::ffi::c_void,
            std::mem::size_of_val(left.as_slice()) as u64,
            options,
        );
        let right = device.new_buffer_with_data(
            right.as_ptr() as *const core::ffi::c_void,
            std::mem::size_of_val(right.as_slice()) as u64,
            options,
        );
        let out = device.new_buffer(
            (out.len() * std::mem::size_of::<f32>()) as NSUInteger,
            options,
        );

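        // NOTE: no GEMM dispatch is encoded on this command buffer yet (e.g. via
        // `call_gemm`), so `out` is never written and the expected values below
        // are placeholders.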
        command_buffer.commit();
        command_buffer.wait_until_completed();
        let results = out.read_to_vec::<f32>(b * m * n);
        assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
    }
}

candle-metal-kernels/src/libMetalFlashAttention.metallib (new binary file, not shown)