Mirror of https://github.com/huggingface/candle.git
Compare: 0.9.0-alph...metal-gemm (10 commits)
7ec4f64d38, 8712ceb84f, f9b2bb4d46, aefca7f8e6, c3b0757995, 560f666d29, 16b49dfd89, fc0deede31, e178aacead, bb191c25d5
@@ -173,8 +173,8 @@ impl Device {
 
     pub fn supports_bf16(&self) -> bool {
         match self {
-            Self::Cuda(_) => true,
-            Self::Metal(_) | Self::Cpu => false,
+            Self::Cuda(_) | Self::Metal(_) => true,
+            Self::Cpu => false,
        }
    }
 
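With Metal now reporting bf16 support, a caller can branch on `Device::supports_bf16` when choosing a dtype. A minimal sketch, assuming candle's public `Device` and `DType` types; the `preferred_dtype` helper is hypothetical, not part of the crate:

use candle_core::{DType, Device};

// Hypothetical helper: prefer bf16 where the backend supports it,
// otherwise fall back to f32.
fn preferred_dtype(device: &Device) -> DType {
    if device.supports_bf16() {
        DType::BF16
    } else {
        DType::F32
    }
}

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu; // or Device::new_metal(0)? / Device::new_cuda(0)?
    let dtype = preferred_dtype(&device);
    println!("using {dtype:?} on {device:?}");
    Ok(())
}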
candle-metal-kernels/src/gemm.metal (new file, 921 lines)
@@ -0,0 +1,921 @@
//
//  GEMM.metal
//  MetalFlashAttention
//
//  Created by Philip Turner on 6/23/23.
//
#include <metal_stdlib>

#ifndef __METAL_SIMDGROUP_EVENT
#define __METAL_SIMDGROUP_EVENT

// Invoking the generation of LLVM bitcode for async copies.
//
//   %struct._simdgroup_event_t = type opaque
//
struct _simdgroup_event_t;

// Invoking the generation of LLVM bitcode for async copies.
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
  ulong, ulong, threadgroup void *, const device void *, ulong)
  __asm("air.simdgroup_async_copy_1d.p3i8.p1i8");

// Invoking the generation of LLVM bitcode for async copies.
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
  ulong, ulong, device void *, const threadgroup void *, ulong)
  __asm("air.simdgroup_async_copy_1d.p1i8.p3i8");

// Invoking the generation of LLVM bitcode for async copies.
//
//   ; Function Attrs: argmemonly convergent nounwind
//   declare %struct._simdgroup_event_t*
//   @air.simdgroup_async_copy_2d.p3i8.p1i8(
//     i64, i64,
//     i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>,
//     i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>,
//     <2 x i64>, i32)
//   local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
  ulong, ulong,
  threadgroup void *, ulong, ulong, ulong2,
  const device void *, ulong, ulong, ulong2,
  long2, int)
  __asm("air.simdgroup_async_copy_2d.p3i8.p1i8");

// Invoking the generation of LLVM bitcode for async copies.
//
//   ; Function Attrs: argmemonly convergent nounwind
//   declare %struct._simdgroup_event_t*
//   @air.simdgroup_async_copy_2d.p1i8.p3i8(
//     i64, i64,
//     i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>,
//     i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>,
//     <2 x i64>, i32)
//   local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
  ulong, ulong,
  device void *, ulong, ulong, ulong2,
  const threadgroup void *, ulong, ulong, ulong2,
  long2, int)
  __asm("air.simdgroup_async_copy_2d.p1i8.p3i8");

// Invoking the generation of LLVM bitcode for async copies.
//
//   ; Function Attrs: convergent nounwind
//   declare void
//   @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
//   local_unnamed_addr #3
//
void __metal_wait_simdgroup_events(
  int, thread _simdgroup_event_t**)
  __asm("air.wait_simdgroup_events");

#pragma METAL internals : enable
namespace metal
{
  enum class simdgroup_async_copy_clamp_mode {
    clamp_to_zero = 0,
    clamp_to_edge = 1
  };

  struct simdgroup_event {
    METAL_FUNC simdgroup_event() thread {}

    template <typename T>
    METAL_FUNC void async_copy(
      threadgroup T *dst,
      const device T *src,
      ulong n_elements
    ) thread {
      event = __metal_simdgroup_async_copy_1d(
        // Description of the data type.
        sizeof(T),
        alignof(T),

        // Description of the arguments.
        reinterpret_cast<threadgroup void *>(dst),
        reinterpret_cast<const device void *>(src),
        n_elements);
    }

    template <typename T>
    METAL_FUNC void async_copy(
      device T *dst,
      const threadgroup T *src,
      ulong n_elements
    ) thread {
      event = __metal_simdgroup_async_copy_1d(
        // Description of the data type.
        sizeof(T),
        alignof(T),

        // Description of the arguments.
        reinterpret_cast<device void *>(dst),
        reinterpret_cast<const threadgroup void *>(src),
        n_elements);
    }

    template <typename T>
    METAL_FUNC void async_copy(
      // Description of the destination.
      threadgroup T *dst,
      ushort dst_elements_per_row,
      ushort2 dst_tile_dimensions,

      // Description of the source.
      const device T *src,
      uint src_elements_per_row,
      ushort2 src_tile_dimensions,

      // Other arguments.
      bool transpose_matrix = false,
      simdgroup_async_copy_clamp_mode clamp_mode =
        simdgroup_async_copy_clamp_mode::clamp_to_zero
    ) thread {
      if (transpose_matrix) {
        src_tile_dimensions = src_tile_dimensions.yx;
        dst_tile_dimensions = dst_tile_dimensions.yx;
      }
      event = __metal_simdgroup_async_copy_2d(
        // Description of the data type.
        sizeof(T),
        alignof(T),

        // Description of the destination.
        reinterpret_cast<threadgroup void *>(dst),
        ushort(dst_elements_per_row),
        1,
        ulong2(dst_tile_dimensions),

        // Description of the source.
        reinterpret_cast<const device void *>(src),
        uint(src_elements_per_row),
        1,
        ulong2(src_tile_dimensions),

        // Other arguments.
        long2(0),
        static_cast<int>(clamp_mode));
    }

    template <typename T>
    METAL_FUNC void async_copy(
      // Description of the destination.
      device T *dst,
      uint dst_elements_per_row,
      ushort2 dst_tile_dimensions,

      // Description of the source.
      const threadgroup T *src,
      ushort src_elements_per_row,
      ushort2 src_tile_dimensions,

      // Other arguments.
      bool transpose_matrix = false
    ) thread {
      if (transpose_matrix) {
        src_tile_dimensions = src_tile_dimensions.yx;
        dst_tile_dimensions = dst_tile_dimensions.yx;
      }
      event = __metal_simdgroup_async_copy_2d(
        // Description of the data type.
        sizeof(T),
        alignof(T),

        // Description of the destination.
        reinterpret_cast<device void *>(dst),
        uint(dst_elements_per_row),
        1,
        ulong2(dst_tile_dimensions),

        // Description of the source.
        reinterpret_cast<const threadgroup void *>(src),
        ushort(src_elements_per_row),
        1,
        ulong2(src_tile_dimensions),

        // Other arguments.
        long2(0),
        0);
    }

    METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
      __metal_wait_simdgroup_events(
        count, reinterpret_cast<thread _simdgroup_event_t**>(events));
    }

  private:
    // Invoking the generation of LLVM bitcode for async copies.
    //
    //   %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
    //
    thread _simdgroup_event_t* event;
  };
} // namespace metal
#pragma METAL internals : disable
#endif

// -*- Metal -*-
|
||||
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
|
||||
// Copyright (c) 2023 Philip Turner. See MIT LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
#define __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
|
||||
// Contains C++ symbols accessible to a developer through automatic code
|
||||
// completion in Xcode 14.2. Formatted with the same style as the Metal Standard
|
||||
// Library for consistency with other Metal code.
|
||||
|
||||
#if defined(__HAVE_SIMDGROUP_MATRIX__)
|
||||
#pragma METAL internals : enable
|
||||
namespace metal
|
||||
{
|
||||
template <typename T>
|
||||
struct simdgroup_matrix_storage {
|
||||
typedef vec<T, 64> storage_type;
|
||||
|
||||
storage_type t;
|
||||
|
||||
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
|
||||
return reinterpret_cast<thread vec<T, 2>*>(&t);
|
||||
}
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage() thread = default;
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
|
||||
*(this->thread_elements()) = thread_elements;
|
||||
}
|
||||
|
||||
METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) {
|
||||
// https://patents.google.com/patent/US11256518B2
|
||||
ushort lane_id = thread_index_in_simdgroup;
|
||||
ushort quad_id = lane_id / 4;
|
||||
|
||||
constexpr ushort QUADRANT_SPAN_M = 4;
|
||||
constexpr ushort THREADS_PER_QUADRANT = 8;
|
||||
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
|
||||
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
|
||||
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
|
||||
|
||||
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
|
||||
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
|
||||
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
|
||||
|
||||
return ushort2(N_in_simd, M_in_simd);
|
||||
}
|
||||
|
||||
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
|
||||
} else {
|
||||
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
|
||||
} else {
|
||||
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: All load and store functions assume the X dimension is divisible by 2.
|
||||
|
||||
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
*(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
|
||||
} else {
|
||||
*(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
*(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
|
||||
} else {
|
||||
*(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
|
||||
} else {
|
||||
thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
|
||||
} else {
|
||||
thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
|
||||
dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
|
||||
} else {
|
||||
*reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
|
||||
} else {
|
||||
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
|
||||
} else {
|
||||
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
|
||||
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
|
||||
} else {
|
||||
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V>
|
||||
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
|
||||
if (!accumulate) {
|
||||
*(thread_elements()) = vec<T, 2>(0);
|
||||
}
|
||||
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
|
||||
}
|
||||
|
||||
// 'bfloat' is 'float' with the lower 16 bits set to garbage (BF16).
|
||||
|
||||
METAL_FUNC thread ushort4* thread_elements_bfloat() thread {
|
||||
thread float2* elements = thread_elements();
|
||||
return reinterpret_cast<thread ushort4*>(elements);
|
||||
}
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage<float> unpack_bfloat() thread {
|
||||
ushort4 output;
|
||||
thread ushort2& elements = thread_elements();
|
||||
output.y = elements[0];
|
||||
output.w = elements[1];
|
||||
return simdgroup_matrix_storage(as_type<float2>(output));
|
||||
}
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage<ushort> pack_bfloat() thread {
|
||||
thread ushort4* elements = thread_elements_bfloat();
|
||||
return simdgroup_matrix_storage(ushort2(elements->y, elements->w));
|
||||
}
|
||||
|
||||
METAL_FUNC void load_bfloat(const threadgroup ushort *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
thread_elements_bfloat()->y = src[matrix_origin.x * elements_per_row + matrix_origin.y];
|
||||
thread_elements_bfloat()->w = src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y];
|
||||
} else {
|
||||
thread_elements_bfloat()->zw = *reinterpret_cast<const threadgroup ushort2*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
|
||||
thread_elements_bfloat()->y = thread_elements_bfloat()->z;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_bfloat(threadgroup ushort *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = *(thread_elements_bfloat()).y;
|
||||
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = *(thread_elements_bfloat()).w;
|
||||
} else {
|
||||
*(thread_elements_bfloat()).z = *(thread_elements_bfloat()).y;
|
||||
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements_bfloat()).zw;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace metal
|
||||
#pragma METAL internals : disable
|
||||
#endif
|
||||
|
||||
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// MARK: - Function Constants
|
||||
|
||||
// Dimensions of each matrix.
|
||||
constant uint M [[function_constant(0)]];
|
||||
constant uint N [[function_constant(1)]];
|
||||
constant uint K [[function_constant(2)]];
|
||||
|
||||
// Whether each matrix is transposed.
|
||||
constant bool A_trans [[function_constant(10)]];
|
||||
constant bool B_trans [[function_constant(11)]];
|
||||
constant bool D_trans [[function_constant(13)]];
|
||||
constant uint A_leading_dim = A_trans ? M : K;
|
||||
constant uint B_leading_dim = B_trans ? K : N;
|
||||
|
||||
// Alpha and beta constants from BLAS.
|
||||
constant float alpha [[function_constant(20)]];
|
||||
constant float beta [[function_constant(21)]];
|
||||
|
||||
constant bool batched [[function_constant(100)]];
|
||||
constant bool fused_activation [[function_constant(101)]];
|
||||
constant bool fused_bias [[function_constant(50001)]]; // 102
|
||||
constant bool use_bias = is_function_constant_defined(fused_bias) ? fused_bias : false;
|
||||
constant bool use_activation_function = fused_activation && !fused_bias;
|
||||
constant bool use_activation = use_bias || use_activation_function;
|
||||
constant bool batched_activation_function = batched && use_activation_function;
|
||||
|
||||
constant ushort M_simd [[function_constant(200)]];
|
||||
constant ushort N_simd [[function_constant(201)]];
|
||||
constant ushort K_simd [[function_constant(202)]];
|
||||
|
||||
// Elide work on the edge when matrix dimension < SRAM block dimension.
|
||||
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
|
||||
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
|
||||
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
|
||||
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
|
||||
|
||||
constant ushort M_splits [[function_constant(210)]];
|
||||
constant ushort N_splits [[function_constant(211)]];
|
||||
|
||||
constant ushort M_group = M_simd * M_splits;
|
||||
constant ushort N_group = N_simd * N_splits;
|
||||
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
|
||||
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);
|
||||
|
||||
// There is no padding for M reads/writes.
|
||||
// There is no padding for N reads/writes.
|
||||
constant ushort K_simd_unpadded = (K % K_simd == 0) ? K_simd : (K % K_simd);
|
||||
constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;
|
||||
|
||||
constant ushort A_sram_length = (M_simd / 8) * 1;
|
||||
constant ushort B_sram_length = 1 * (N_simd / 8);
|
||||
constant ushort A_block_length = M_group * K_simd;
|
||||
|
||||
// Threadgroup block must fit entire C accumulator and partial sums.
|
||||
constant ushort A_sram_offset = 0;
|
||||
constant ushort B_sram_offset = A_sram_offset + A_sram_length;
|
||||
constant ushort C_sram_offset = B_sram_offset + B_sram_length;
|
||||
constant ushort A_block_offset = 0;
|
||||
constant ushort B_block_offset = A_block_offset + A_block_length;
|
||||
|
||||
// MARK: - Utilities
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
|
||||
// A_sram[M_simd][8]
|
||||
return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
|
||||
// B_sram[8][N_simd]
|
||||
return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
|
||||
// C_sram[M_simd][N_simd]
|
||||
return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void prefetch(threadgroup T *A_block, device T *A,
|
||||
ushort2 A_tile_src, uint2 A_offset,
|
||||
threadgroup T *B_block, device T *B,
|
||||
ushort2 B_tile_src, uint2 B_offset, uint k)
|
||||
{
|
||||
A_tile_src.x = min(uint(K_simd), K - k);
|
||||
B_tile_src.y = min(uint(K_simd), K - k);
|
||||
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
|
||||
auto B_src = simdgroup_matrix_storage<T>::apply_offset(B, B_leading_dim, B_offset, B_trans);
|
||||
|
||||
// Rounded-up ceiling for the threadgroup block.
|
||||
const uint K_edge_floor = K - K_simd_unpadded;
|
||||
const uint K_edge_ceil = K_edge_floor + K_simd_padded;
|
||||
ushort K_padded;
|
||||
if (K_edge_floor == K_simd) {
|
||||
K_padded = K_simd;
|
||||
} else {
|
||||
K_padded = min(uint(K_simd), K_edge_ceil - k);
|
||||
}
|
||||
ushort2 A_tile_dst(K_padded, A_tile_src.y);
|
||||
ushort2 B_tile_dst(B_tile_src.x, K_padded);
|
||||
|
||||
simdgroup_event events[2];
|
||||
events[0].async_copy(A_block, A_block_leading_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
|
||||
events[1].async_copy(B_block, B_block_leading_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
|
||||
simdgroup_event::wait(2, events);
|
||||
}
|
||||
|
||||
// One iteration of the MACC loop, effectively k=8 iterations.
|
||||
template <typename T>
|
||||
METAL_FUNC void multiply_accumulate(thread simdgroup_matrix_storage<T> *sram,
|
||||
const threadgroup T *A_block,
|
||||
const threadgroup T *B_block,
|
||||
bool accumulate = true)
|
||||
{
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
ushort2 origin(0, m);
|
||||
A_sram(sram, origin)->load(A_block, A_block_leading_dim, origin, A_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
ushort2 origin(n, 0);
|
||||
B_sram(sram, origin)->load(B_block, B_block_leading_dim, origin, B_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
auto A = A_sram(sram, ushort2(0, m));
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
auto B = B_sram(sram, ushort2(n, 0));
|
||||
auto C = C_sram(sram, ushort2(n, m));
|
||||
C->multiply(*A, *B, accumulate);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void partial_store(thread simdgroup_matrix_storage<T> *sram,
|
||||
threadgroup T *C_block, bool is_k_summation)
|
||||
{
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
if (is_k_summation) {
|
||||
C_sram(sram, origin)->store(C_block, N_simd, origin);
|
||||
} else {
|
||||
C_sram(sram, origin)->store(C_block, N_group, origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void partial_accumulate(thread simdgroup_matrix_storage<T> *sram,
|
||||
threadgroup T *C_block, bool is_k_summation)
|
||||
{
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
auto B = B_sram(sram, ushort2(n, 0));
|
||||
if (is_k_summation) {
|
||||
B->load(C_block, N_simd, origin);
|
||||
} else {
|
||||
B->load(C_block, N_group, origin);
|
||||
}
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
auto B = B_sram(sram, ushort2(n, 0));
|
||||
auto C = C_sram(sram, origin);
|
||||
if (is_k_summation) {
|
||||
C->thread_elements()[0] += B->thread_elements()[0];
|
||||
} else {
|
||||
float2 C_old = float2(B->thread_elements()[0]);
|
||||
float2 C_new = float2(C->thread_elements()[0]);
|
||||
C->thread_elements()[0] = vec<T, 2>(fast::fma(C_old, beta, C_new));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_access_accumulator(threadgroup T *C_block, device T *C,
|
||||
uint2 C_offset, bool is_store)
|
||||
{
|
||||
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
|
||||
min(uint(M_group), M - C_offset.y));
|
||||
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, C_offset);
|
||||
|
||||
if (is_store) {
|
||||
simdgroup_event event;
|
||||
event.async_copy(C_src, N, C_tile, C_block, N_group, C_tile);
|
||||
simdgroup_event::wait(1, &event);
|
||||
} else {
|
||||
simdgroup_event event;
|
||||
event.async_copy(C_block, N_group, C_tile, C_src, N, C_tile);
|
||||
simdgroup_event::wait(1, &event);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
|
||||
device T *C, bool m_is_edge, bool n_is_edge)
|
||||
{
|
||||
const ushort m_start = (m_is_edge) ? M_modulo : 0;
|
||||
const ushort n_start = (n_is_edge) ? N_modulo : 0;
|
||||
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
|
||||
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = m_start; m < m_end; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = n_start; n < n_end; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
C_sram(sram, origin)->store(C, N, origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct activation_functor {
|
||||
using function = void(threadgroup T *C,
|
||||
device void *D,
|
||||
uint grid_index_in_batch,
|
||||
uint2 matrix_origin,
|
||||
ushort2 tile_dimensions,
|
||||
ushort lane_id);
|
||||
|
||||
typedef visible_function_table<function> function_table;
|
||||
};
|
||||
|
||||
// MARK: - Kernels
|
||||
|
||||
template <typename T>
|
||||
void _gemm_impl(device T *A [[buffer(0)]],
|
||||
device T *B [[buffer(1)]],
|
||||
device T *C [[buffer(2)]],
|
||||
device void *D [[buffer(3), function_constant(use_activation)]],
|
||||
|
||||
threadgroup T *threadgroup_block [[threadgroup(0)]],
|
||||
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
typename activation_functor<T>::function_table table [[buffer(11), function_constant(use_activation_function)]],
|
||||
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]])
|
||||
{
|
||||
if (batched) {
|
||||
// TODO: Re-compute every inner loop iteration for FP64 accumulate.
|
||||
ulong3 offsets = matrix_offsets[0].xyz * gid.z;
|
||||
A = (device T*)((device uchar*)A + offsets[0]);
|
||||
B = (device T*)((device uchar*)B + offsets[1]);
|
||||
C = (device T*)((device uchar*)C + offsets[2]);
|
||||
}
|
||||
|
||||
simdgroup_matrix_storage<T> sram[1024];
|
||||
auto A_block = threadgroup_block + A_block_offset;
|
||||
auto B_block = threadgroup_block + B_block_offset;
|
||||
ushort2 sid(sidx % N_splits, sidx / N_splits);
|
||||
ushort2 offset_in_simd = simdgroup_matrix_storage<T>::offset(lane_id);
|
||||
|
||||
uint2 A_offset(0, gid.y * M_group);
|
||||
uint2 B_offset(gid.x * N_group, 0);
|
||||
{
|
||||
uint C_base_offset_x = B_offset.x + sid.x * N_simd;
|
||||
uint C_base_offset_y = A_offset.y + sid.y * M_simd;
|
||||
if (C_base_offset_x >= N || C_base_offset_y >= M) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ushort2 offset_in_group(sid.x * N_simd + offset_in_simd.x,
|
||||
sid.y * M_simd + offset_in_simd.y);
|
||||
|
||||
if (use_bias) {
|
||||
if (sidx == 0) {
|
||||
auto bias = (device T*)D;
|
||||
if (batched) {
|
||||
ulong offset = matrix_offsets[gid.z].w;
|
||||
bias = (device T*)((device uchar*)bias + offset);
|
||||
}
|
||||
|
||||
ushort bias_elements;
|
||||
if (is_function_constant_defined(D_trans) && D_trans) {
|
||||
bias += A_offset.y;
|
||||
bias_elements = min(uint(M_group), M - A_offset.y);
|
||||
} else {
|
||||
bias += B_offset.x;
|
||||
bias_elements = min(uint(N_group), N - B_offset.x);
|
||||
}
|
||||
|
||||
simdgroup_event event;
|
||||
event.async_copy(threadgroup_block, bias, bias_elements);
|
||||
simdgroup_event::wait(1, &event);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (is_function_constant_defined(D_trans) && D_trans) {
|
||||
auto bias = threadgroup_block + offset_in_group.y;
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
auto D = bias[m];
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
auto C = C_sram(sram, ushort2(n, m));
|
||||
*(C->thread_elements()) = D;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto bias = threadgroup_block + offset_in_group.x;
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
auto D = *(threadgroup vec<T, 2>*)(bias + n);
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
auto C = C_sram(sram, ushort2(n, m));
|
||||
*(C->thread_elements()) = D;
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
ushort2 A_tile_src;
|
||||
ushort2 B_tile_src;
|
||||
if (sidx == 0) {
|
||||
A_tile_src.y = min(uint(M_group), M - A_offset.y);
|
||||
B_tile_src.x = min(uint(N_group), N - B_offset.x);
|
||||
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, 0);
|
||||
}
|
||||
|
||||
if (K > K_simd && !use_bias) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
*C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<T>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (uint K_floor = 0; K_floor < K; K_floor += K_simd) {
|
||||
ushort2 A_block_offset(offset_in_simd.x, offset_in_group.y);
|
||||
ushort2 B_block_offset(offset_in_group.x, offset_in_simd.y);
|
||||
auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(A_block, A_block_leading_dim, A_block_offset, A_trans);
|
||||
auto B_block_src = simdgroup_matrix_storage<T>::apply_offset(B_block, B_block_leading_dim, B_block_offset, B_trans);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort k = 0; k < K_simd_padded; k += 8) {
|
||||
bool accumulate = use_bias || !(K <= K_simd && k == 0);
|
||||
multiply_accumulate(sram, A_block_src, B_block_src, accumulate);
|
||||
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
|
||||
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
|
||||
}
|
||||
|
||||
if (K_floor + K_simd < K) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort k = K_simd_padded; k < K_simd; k += 8) {
|
||||
multiply_accumulate(sram, A_block_src, B_block_src);
|
||||
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
|
||||
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sidx == 0) {
|
||||
uint K_next = K_floor + K_simd;
|
||||
A_offset.x = K_next;
|
||||
B_offset.y = K_next;
|
||||
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, K_next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (alpha != 1) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int m = 0; m < M_padded; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int n = 0; n < N_padded; n += 8) {
|
||||
C_sram(sram, ushort2(n, m))->thread_elements()[0] *= T(alpha);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint2 C_offset(B_offset.x, A_offset.y);
|
||||
ushort2 C_block_offset = offset_in_group.xy;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (beta != 0) {
|
||||
if (sidx == 0) {
|
||||
async_access_accumulator(threadgroup_block, C, C_offset, false);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
|
||||
partial_accumulate(sram, C_block, false);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
if (use_activation_function) {
|
||||
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
|
||||
partial_store(sram, C_block, false);
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
uint grid_index_in_batch = (batched ? gid.z : 0);
|
||||
uint2 matrix_origin = C_offset + uint2(C_block_offset);
|
||||
matrix_origin &= ~7;
|
||||
ushort2 tile_dimensions(min(uint(N_group), N - matrix_origin.x),
|
||||
min(uint(M_group), M - matrix_origin.y));
|
||||
uint function_index = 0;
|
||||
if (batched_activation_function) {
|
||||
function_index = activation_function_offsets[gid.z];
|
||||
}
|
||||
table[function_index](C_block, D, grid_index_in_batch, matrix_origin, tile_dimensions, lane_id);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sidx == 0) {
|
||||
async_access_accumulator(threadgroup_block, C, C_offset, true);
|
||||
}
|
||||
return;
|
||||
} else if ((M % 8 != 0) || (N % 8 != 0)) {
|
||||
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
|
||||
partial_store(sram, C_block, false);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sidx == 0) {
|
||||
async_access_accumulator(threadgroup_block, C, C_offset, true);
|
||||
}
|
||||
} else {
|
||||
uint2 matrix_origin = C_offset + uint2(C_block_offset);
|
||||
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, matrix_origin);
|
||||
store_accumulator(sram, C_src, false, false);
|
||||
|
||||
const uint M_edge_floor = M - M % M_simd;
|
||||
const uint N_edge_floor = N - N % N_simd;
|
||||
if (matrix_origin.y < M_edge_floor) {
|
||||
store_accumulator(sram, C_src, true, false);
|
||||
}
|
||||
if (matrix_origin.x < N_edge_floor) {
|
||||
store_accumulator(sram, C_src, false, true);
|
||||
if (matrix_origin.y < M_edge_floor) {
|
||||
store_accumulator(sram, C_src, true, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void hgemm(device half *A [[buffer(0)]],
|
||||
device half *B [[buffer(1)]],
|
||||
device half *C [[buffer(2)]],
|
||||
device void *D [[buffer(3), function_constant(use_activation)]],
|
||||
|
||||
threadgroup half *threadgroup_block [[threadgroup(0)]],
|
||||
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
typename activation_functor<half>::function_table table [[buffer(11), function_constant(use_activation_function)]],
|
||||
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]])
|
||||
{
|
||||
_gemm_impl<half>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
|
||||
}
|
||||
|
||||
kernel void sgemm(device float *A [[buffer(0)]],
|
||||
device float *B [[buffer(1)]],
|
||||
device float *C [[buffer(2)]],
|
||||
device void *D [[buffer(3), function_constant(use_activation)]],
|
||||
|
||||
threadgroup float *threadgroup_block [[threadgroup(0)]],
|
||||
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
typename activation_functor<float>::function_table table [[buffer(11), function_constant(use_activation_function)]],
|
||||
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]])
|
||||
{
|
||||
_gemm_impl<float>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
|
||||
}
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
kernel void bgemm(
|
||||
device bfloat *A [[buffer(0)]],
|
||||
device bfloat *B [[buffer(1)]],
|
||||
device bfloat *C [[buffer(2)]],
|
||||
device void *D [[buffer(3), function_constant(use_activation)]],
|
||||
|
||||
threadgroup bfloat *threadgroup_block [[threadgroup(0)]],
|
||||
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
typename activation_functor<bfloat>::function_table table [[buffer(11), function_constant(use_activation_function)]],
|
||||
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]])
|
||||
{
|
||||
_gemm_impl<bfloat>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
|
||||
}
|
||||
#endif
|
@@ -20,7 +20,8 @@ const CONV: &str = include_str!("conv.metal");
 const REDUCE: &str = include_str!("reduce.metal");
 const RANDOM: &str = include_str!("random.metal");
 // Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
-const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
+// const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
+const GEMM: &str = include_str!("gemm.metal");
 const QUANTIZED: &str = include_str!("quantized.metal");
 const SORT: &str = include_str!("sort.metal");
 
@@ -201,7 +202,7 @@ impl Kernels {
             Source::Random => RANDOM,
             Source::Quantized => QUANTIZED,
             Source::Sort => SORT,
-            Source::Mfa => panic!("Invalid lib"),
+            Source::Mfa => GEMM,
         }
     }
 
@@ -216,22 +217,14 @@ impl Kernels {
         if let Some(lib) = libraries.get(&source) {
             Ok(lib.clone())
         } else {
-            let lib = match source {
-                Source::Mfa => {
-                    let source_data = MFA;
-                    device.new_library_with_data(source_data).map_err(|e| {
-                        MetalKernelError::LoadLibraryError(format!(
-                            "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
-                        ))
-                    })?
-                }
-                source => {
-                    let source_content = self.get_library_source(source);
-                    device
-                        .new_library_with_source(source_content, &CompileOptions::new())
-                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
-                }
-            };
+            let compile_options = CompileOptions::new();
+            compile_options.set_fast_math_enabled(true);
+
+            let source_content = self.get_library_source(source);
+            let lib = device
+                .new_library_with_source(source_content, &CompileOptions::new())
+                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
+
             libraries.insert(source, lib.clone());
             Ok(lib)
         }
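With this change the kernels are compiled at runtime from the embedded gemm.metal source instead of loading the prebuilt libMetalFlashAttention.metallib via new_library_with_data. A minimal standalone sketch of that compile path, assuming the metal crate (metal-rs); the GEMM_SRC constant and compile_gemm helper are illustrative, not candle's API:

// Sketch: compile the embedded Metal source with fast math enabled.
use metal::{CompileOptions, Device};

const GEMM_SRC: &str = include_str!("gemm.metal");

fn compile_gemm(device: &Device) -> Result<metal::Library, String> {
    let options = CompileOptions::new();
    options.set_fast_math_enabled(true);
    device
        .new_library_with_source(GEMM_SRC, &options)
        .map_err(|e| format!("failed to compile gemm.metal: {e}"))
}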
Binary file not shown.
@@ -1162,27 +1162,25 @@ fn gemm() {
     );
 
     // bgemm sanity test
-    if false {
-        let (b, m, n, k) = (1, 2, 4, 3);
-        let lhs_stride = vec![m * k, k, 1];
-        let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
-        let rhs_stride = vec![n * k, n, 1];
-        let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
-        let results = run_gemm(
-            "bgemm",
-            (b, m, n, k),
-            &lhs,
-            lhs_stride,
-            0,
-            &rhs,
-            rhs_stride,
-            0,
-        );
-        assert_eq!(
-            approx_bf16(results, 4),
-            vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
-        );
-    }
+    let (b, m, n, k) = (1, 2, 4, 3);
+    let lhs_stride = vec![m * k, k, 1];
+    let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
+    let rhs_stride = vec![n * k, n, 1];
+    let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
+    let results = run_gemm(
+        "bgemm",
+        (b, m, n, k),
+        &lhs,
+        lhs_stride,
+        0,
+        &rhs,
+        rhs_stride,
+        0,
+    );
+    assert_eq!(
+        approx_bf16(results, 4),
+        vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+    );
 
     // hgemm sanity test
     let (b, m, n, k) = (1, 2, 4, 3);
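The expected bgemm values above can be reproduced with a plain f32 reference matmul over the same row-major inputs. A quick standalone check, not part of the test suite:

// lhs is 2x3, rhs is 3x4, both filled with 0, 1, 2, ... in row-major order.
fn main() {
    let (m, n, k) = (2usize, 4usize, 3usize);
    let lhs: Vec<f32> = (0..m * k).map(|v| v as f32).collect();
    let rhs: Vec<f32> = (0..k * n).map(|v| v as f32).collect();
    let mut out = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                out[i * n + j] += lhs[i * k + p] * rhs[p * n + j];
            }
        }
    }
    // Prints [20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
    println!("{out:?}");
}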