Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)

Commit: batched gemm work
@@ -2,11 +2,11 @@ mod benchmarks;
 use criterion::criterion_main;
 criterion_main!(
-    benchmarks::affine::benches,
+    //benchmarks::affine::benches,
     benchmarks::matmul::benches,
-    benchmarks::random::benches,
-    benchmarks::where_cond::benches,
-    benchmarks::conv_transpose2d::benches,
-    benchmarks::qmatmul::benches,
-    benchmarks::unary::benches
+    //benchmarks::random::benches,
+    //benchmarks::where_cond::benches,
+    //benchmarks::conv_transpose2d::benches,
+    //benchmarks::qmatmul::benches,
+    //benchmarks::unary::benches
 );

@@ -44,22 +44,13 @@ constant uint K [[function_constant(2)]];
 constant bool A_trans [[function_constant(10)]];
 constant bool B_trans [[function_constant(11)]];

-// Define the memory layout of the matrix block.
-constant ushort M_group [[function_constant(200)]];
-constant ushort N_group [[function_constant(201)]];
-constant ushort K_group [[function_constant(202)]];
-
 constant bool prefer_async_copy [[function_constant(206)]];
 constant bool ideal_grouping [[function_constant(207)]];

+constant bool batched [[function_constant(100)]];
+
 constant ushort A_leading_dim = A_trans ? M : K;
 constant ushort B_leading_dim = B_trans ? K : N;
-constant ushort A_leading_block_dim = A_trans ? M_group : K_group;
-constant ushort B_leading_block_dim = B_trans ? K_group : N_group;
-
-// Thresholds that mark the matrix edge.
-constant uint M_edge = M - (M % M_group);
-constant uint N_edge = N - (N % N_group);

 // The layout of threads within a SIMD matrix.
 //
@@ -123,28 +114,28 @@ METAL_FUNC void multiply_accumulate(
     thread simdgroup_matrix_storage<U> *C_sram,
     ushort k
 ) {
 #pragma clang loop unroll(full)
     for (ushort m = 0; m < M_register; m += 8) {
         ushort2 origin(0, m);
         auto A = get_sram(A_sram, 8, origin);
         A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
     }
 #pragma clang loop unroll(full)
     for (ushort n = 0; n < N_register; n += 8) {
         ushort2 origin(n, 0);
         auto B = get_sram(B_sram, N_register, origin);
         B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
     }
 #pragma clang loop unroll(full)
     for (ushort m = 0; m < M_register; m += 8) {
 #pragma clang loop unroll(full)
         for (ushort n = 0; n < N_register; n += 8) {
             auto A = get_sram(A_sram, 8, ushort2(0, m));
             auto B = get_sram(B_sram, N_register, ushort2(n, 0));
             auto C = get_sram(C_sram, N_register, ushort2(n, m));
             C->multiply(*A, *B);
         }
     }
 }

 // One multiply-accumulate loop iteration, or 8 dot products.
@@ -162,28 +153,28 @@ METAL_FUNC void multiply_accumulate(
     thread simdgroup_matrix_storage<U> *C_sram,
     ushort k
 ) {
 #pragma clang loop unroll(full)
     for (ushort m = 0; m < M_register; m += 8) {
         ushort2 origin(0, m);
         auto A = get_sram(A_sram, 8, origin);
         A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
     }
 #pragma clang loop unroll(full)
     for (ushort n = 0; n < N_register; n += 8) {
         ushort2 origin(n, 0);
         auto B = get_sram(B_sram, N_register, origin);
         B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
     }
 #pragma clang loop unroll(full)
     for (ushort m = 0; m < M_register; m += 8) {
 #pragma clang loop unroll(full)
         for (ushort n = 0; n < N_register; n += 8) {
             auto A = get_sram(A_sram, 8, ushort2(0, m));
             auto B = get_sram(B_sram, N_register, ushort2(n, 0));
             auto C = get_sram(C_sram, N_register, ushort2(n, m));
             C->multiply(*A, *B);
         }
     }
 }

 // Metal function arguments.
@@ -191,19 +182,19 @@ METAL_FUNC void multiply_accumulate(
 // A: the left-hand side matrix
 // - dimensions: M x K
 //               K x M (transposed)
-// - memory precision: memA
-// - register precision: regA
+// - memory precision: T
+// - register precision: T
 //
 // B: the right-hand side matrix
 // - dimensions: K x N
 //               N x K (transposed)
-// - memory precision: memB
-// - register precision: regB
+// - memory precision: U
+// - register precision: U
 //
 // C: the output matrix, alternatively the dot product accumulator
 // - dimensions: M x N
-// - memory precision: memC
-// - register precision: regC
+// - memory precision: V
+// - register precision: V
 //
 // threadgroup_block: the chunk of threadgroup memory allocated at runtime
 // - ideally 10 KB or less
@@ -211,28 +202,35 @@ METAL_FUNC void multiply_accumulate(
 template <
     typename T,
     typename U = T,
-    ushort M_block_dim,
-    ushort N_block_dim,
-    ushort K_block_dim,
-    ushort M_split,
-    ushort N_split
+    typename V = U,
+    ushort M_group,
+    ushort N_group,
+    ushort K_group,
+    ushort M_splits,
+    ushort N_splits,
+    ushort M_register = M_group / M_splits,
+    ushort N_register = N_group / N_splits
 >
 void gemm_impl(
     device T *A [[buffer(0)]],
     device U *B [[buffer(1)]],
-    device U *C [[buffer(2)]],
+    device V *C [[buffer(2)]],

     threadgroup uchar *threadgroup_block [[threadgroup(0)]],
+    constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],

     uint3 gid [[threadgroup_position_in_grid]],
     ushort sidx [[simdgroup_index_in_threadgroup]],
     ushort lane_id [[thread_index_in_simdgroup]]
 ) {
-    constexpr ushort M_register = M_block_dim / M_split;
-    constexpr ushort N_register = N_block_dim / N_split;
-    constexpr ushort threadgroup_size = 32 * M_split * N_split;
+    const ushort A_leading_block_dim = A_trans ? M_group : K_group;
+    const ushort B_leading_block_dim = B_trans ? K_group : N_group;

-    const ushort iteration_start = prefer_async_copy ? 0 : (K - (K % K_group));
+    // Thresholds that mark the matrix edge.
+    const uint M_edge = M - (M % M_group);
+    const uint N_edge = N - (N % N_group);
+
+    const ushort async_iter_start = prefer_async_copy ? 0 : (K - (K % K_group));

     // Find the number of elements in the final block. If the matrix
     // dimensions are perfectly divisibly by block dimensions, we don't want
@@ -249,9 +247,16 @@ void gemm_impl(
     const ushort M_shift = (M < M_group) ? 0 : M_register - M_remainder;
     const ushort N_shift = (N < N_group) ? 0 : N_register - N_remainder;

+    if (batched) {
+        ulong3 offsets = matrix_offsets[0].xyz * gid.z;
+        A = (device T*)((device uchar*)A + offsets[0]);
+        B = (device U*)((device uchar*)B + offsets[1]);
+        C = (device V*)((device uchar*)C + offsets[2]);
+    }
+
     auto A_block = (threadgroup T*)(threadgroup_block);
-    auto B_block = (threadgroup U*)(threadgroup_block + (M*K));
-    ushort2 sid(sidx % N_split, sidx / N_split);
+    auto B_block = (threadgroup U*)(threadgroup_block + (M * K));
+    ushort2 sid(sidx % N_splits, sidx / N_splits);
     ushort2 morton_offset = morton_order(lane_id);

     // Return early if the SIMD is out of bounds.
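This hunk is where batching lands in the kernel: when the `batched` function constant is set, each threadgroup advances A, B, and C by `matrix_offsets[0].xyz * gid.z` bytes before doing anything else. A host-side sketch of the same arithmetic, assuming contiguous row-major matrices; the helper name and parameters below are illustrative, not part of the patch:

// Illustrative only: compute the per-batch byte offsets that the kernel
// derives from `matrix_offsets[0].xyz * gid.z`, assuming contiguous inputs.
// `elem_size` is 4 for sgemm and 2 for hgemm.
fn batch_byte_offsets(m: u64, n: u64, k: u64, elem_size: u64, batch_index: u64) -> (u64, u64, u64) {
    let byte_stride_a = m * k * elem_size; // A advances one M x K matrix per batch
    let byte_stride_b = k * n * elem_size; // B advances one K x N matrix per batch
    let byte_stride_c = m * n * elem_size; // C advances one M x N matrix per batch
    (
        byte_stride_a * batch_index,
        byte_stride_b * batch_index,
        byte_stride_c * batch_index,
    )
}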
@@ -266,8 +271,8 @@ void gemm_impl(
         N_offset + sid.x * N_register >= N) {
         return;
     }
-    ushort2 offset_in_group(sid.x * M_register + morton_offset.x,
-                            sid.y * N_register + morton_offset.y);
+    ushort2 offset_in_group(sid.x * N_register + morton_offset.x,
+                            sid.y * M_register + morton_offset.y);

     // Shift the matrix block within bounds, if possible.
     if ((M_shift != 0) && (gid.y * M_group >= M_edge)) {
@@ -277,91 +282,98 @@ void gemm_impl(
         N_offset -= N_shift;
     }

-    simdgroup_matrix_storage<U> C_sram[(M_register / 8) * (N_register / 8)];
+    simdgroup_matrix_storage<V> C_sram[(M_register / 8) * (N_register / 8)];

     // Initialize the accumulator.
 #pragma clang loop unroll(full)
     for (ushort m = 0; m < M_register; m += 8) {
 #pragma clang loop unroll(full)
         for (ushort n = 0; n < N_register; n += 8) {
-            ushort2 origin(n, m);
+            ushort2 origin(m, n);
             auto C = get_sram(C_sram, N_register, origin);
-            *C = simdgroup_matrix_storage<U>(0);
+            *C = simdgroup_matrix_storage<V>(0);
         }
     }

     // Perform the iterations where async copy is avoided.
-    for (uint k = 0; k < iteration_start; k += 8) {
+#pragma clang loop unroll(full)
+    for (uint k = 0; k < async_iter_start; k += 8) {
         uint2 A_offset(k, M_offset);
         uint2 B_offset(N_offset, k);
         A_offset += uint2(morton_offset.x, offset_in_group.y);
         B_offset += uint2(offset_in_group.x, morton_offset.y);

-        auto A_src = simdgroup_matrix_storage<T>::apply_offset(
-            A, A_leading_dim, A_offset, A_trans);
-        auto B_src = simdgroup_matrix_storage<U>::apply_offset(
-            B, N, B_offset, B_trans);
+        auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
+        auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);

         simdgroup_matrix_storage<T> A_sram[M_register / 8];
         simdgroup_matrix_storage<U> B_sram[N_register / 8];
-        multiply_accumulate<T, U, M_register, N_register>(
-            A_src, B_src, A_sram, B_sram, C_sram, 0);
+        multiply_accumulate<T, U, M_register, N_register>(A_src, B_src, A_sram, B_sram, C_sram, 0);
     }

-    // Perform the iterations where async copy is used.
-    for (uint k = iteration_start; k < K; k += K_group) {
-        // Launch an async copy from device to threadgroup memory.
-        if (sidx == 0) {
-            uint2 A_offset(k, M_offset);
-            uint2 B_offset(N_offset, k);
-            auto A_src = simdgroup_matrix_storage<T>::apply_offset(
-                A, A_leading_dim, A_offset, A_trans);
-            auto B_src = simdgroup_matrix_storage<U>::apply_offset(
-                B, N, B_offset, B_trans);
-
-            ushort M_tile_dimension = min(uint(M_group), M - M_offset);
-            ushort N_tile_dimension = min(uint(N_group), N - N_offset);
-            ushort K_tile_dimension = min(uint(K_group), K - k);
-            ushort K_tile_padded = min(uint(K_group), (K + K_remainder_padded - K_remainder) - k);
-
-            ushort2 A_tile_src(K_tile_dimension, M_tile_dimension);
-            ushort2 B_tile_src(N_tile_dimension, K_tile_dimension);
-            ushort2 A_tile_dst(K_tile_padded, M_tile_dimension);
-            ushort2 B_tile_dst(N_tile_dimension, K_tile_padded);
-
-            simdgroup_event events[2];
-            events[0].async_copy(A_block, A_leading_block_dim, A_tile_dst,
-                                 A_src, A_leading_dim, A_tile_src, A_trans);
-            events[1].async_copy(B_block, B_leading_block_dim, B_tile_dst,
-                                 B_src, B_leading_dim, B_tile_src, B_trans);
-            simdgroup_event::wait(2, events);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        ushort2 A_block_offset(morton_offset.x, offset_in_group.y);
-        ushort2 B_block_offset(offset_in_group.x, morton_offset.y);
-        auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(
-            A_block, A_leading_block_dim, A_block_offset, A_trans);
-        auto B_block_src = simdgroup_matrix_storage<U>::apply_offset(
-            B_block, B_leading_block_dim, B_block_offset, B_trans);
-
-        simdgroup_matrix_storage<T> A_sram[(M_register / 8) * (K_block_dim / 8)];
-        simdgroup_matrix_storage<U> B_sram[(K_block_dim / 8) * (N_register / 8)];
-#pragma clang loop unroll(full)
-        for (ushort k = 0; k < K_remainder_padded; k += 8) {
-            multiply_accumulate<T, U, M_register, N_register>(
-                A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
-        }
-
-        // Will there be any iterations after this one?
-        if (k + K_group < K) {
-            // If so, we haven't reached the edge of either input matrix yet.
-#pragma clang loop unroll(full)
-            for (ushort k = K_remainder_padded; k < K_group; k += 8) {
-                multiply_accumulate<T, U, M_register, N_register>(
-                    A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
-            }
-            threadgroup_barrier(mem_flags::mem_threadgroup);
-        }
-    }
+    if (!prefer_async_copy) {
+#pragma clang loop unroll(full)
+        for (uint k = 0; k < K; k += K_group) {
+            uint2 A_offset(k, M_offset);
+            uint2 B_offset(N_offset, k);
+            A_offset += uint2(morton_offset.x, offset_in_group.y);
+            B_offset += uint2(offset_in_group.x, morton_offset.y);
+
+            auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
+            auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);
+
+            simdgroup_matrix_storage<T> A_sram[M_register / 8];
+            simdgroup_matrix_storage<U> B_sram[N_register / 8];
+            multiply_accumulate<T, U, M_register, N_register>(A_src, B_src, A_sram, B_sram, C_sram, 0);
+        }
+    } else {
+        // Perform the iterations where async copy is used.
+#pragma clang loop unroll(full)
+        for (uint k = async_iter_start; k < K; k += K_group) {
+            // Launch an async copy from device to threadgroup memory.
+            if (sidx == 0) {
+                uint2 A_offset(k, M_offset);
+                uint2 B_offset(N_offset, k);
+                auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
+                auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);
+
+                ushort M_tile_dimension = min(uint(M_group), M - M_offset);
+                ushort N_tile_dimension = min(uint(N_group), N - N_offset);
+                ushort K_tile_dimension = min(uint(K_group), K - k);
+                ushort K_tile_padded = min(uint(K_group), (K + K_remainder_padded - K_remainder) - k);
+
+                ushort2 A_tile_src(K_tile_dimension, M_tile_dimension);
+                ushort2 B_tile_src(N_tile_dimension, K_tile_dimension);
+                ushort2 A_tile_dst(K_tile_padded, M_tile_dimension);
+                ushort2 B_tile_dst(N_tile_dimension, K_tile_padded);
+
+                simdgroup_event events[2];
+                events[0].async_copy(A_block, A_leading_block_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
+                events[1].async_copy(B_block, B_leading_block_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
+                simdgroup_event::wait(2, events);
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            ushort2 A_block_offset(morton_offset.x, offset_in_group.y);
+            ushort2 B_block_offset(offset_in_group.x, morton_offset.y);
+            auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(A_block, A_leading_block_dim, A_block_offset, A_trans);
+            auto B_block_src = simdgroup_matrix_storage<U>::apply_offset(B_block, B_leading_block_dim, B_block_offset, B_trans);
+
+            simdgroup_matrix_storage<T> A_sram[(M_register / 8) * (K_group / 8)];
+            simdgroup_matrix_storage<U> B_sram[(K_group / 8) * (N_register / 8)];
+
+#pragma clang loop unroll(full)
+            for (ushort k = 0; k < K_remainder_padded; k += 8) {
+                multiply_accumulate<T, U, M_register, N_register>(A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
+            }
+
+            // Will there be any iterations after this one?
+            if (k + K_group < K) {
+                // If so, we haven't reached the edge of either input matrix yet.
+#pragma clang loop unroll(full)
+                for (ushort k = K_remainder_padded; k < K_group; k += 8) {
+                    multiply_accumulate<T, U, M_register, N_register>(A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
+                }
+                threadgroup_barrier(mem_flags::mem_threadgroup);
+            }
+        }
+    }

@@ -384,9 +396,8 @@ void gemm_impl(
         }
     } else {
         // Slow path for when memory must be handled more carefully.
-        auto C_block = (threadgroup U*)(threadgroup_block);
-        auto C_block_dst = simdgroup_matrix_storage<U>::apply_offset(
-            C_block, N_group, offset_in_group);
+        auto C_block = (threadgroup V*)(threadgroup_block);
+        auto C_block_dst = simdgroup_matrix_storage<V>::apply_offset(C_block, N_group, offset_in_group);
         threadgroup_barrier(mem_flags::mem_threadgroup);

         // Write the accumulator to threadgroup memory.
@@ -405,9 +416,8 @@ void gemm_impl(
         if (sidx == 0) {
             uint2 C_offset(gid.x * N_group, gid.y * M_group);
             ushort2 C_tile(min(uint(N_group), N - C_offset.x),
                            min(uint(M_group), M - C_offset.y));
-            auto C_dst = simdgroup_matrix_storage<U>::apply_offset(
-                C, N, C_offset);
+            auto C_dst = simdgroup_matrix_storage<V>::apply_offset(C, N, C_offset);

             // If we shift successfully, the garbage zone moves from the bottom right
             // to the top left.
@@ -419,8 +429,7 @@ void gemm_impl(
             if ((N_shift != 0) && (C_offset.x >= N_edge)) {
                 C_block_shift.x = N_shift;
             }
-            C_block = simdgroup_matrix_storage<U>::apply_offset(
-                C_block, N_group, C_block_shift);
+            C_block = simdgroup_matrix_storage<V>::apply_offset(C_block, N_group, C_block_shift);
         }

         simdgroup_event event;
@@ -435,34 +444,19 @@ kernel void hgemm(
     device half *C [[buffer(2)]],

     threadgroup uchar *threadgroup_block [[threadgroup(0)]],
+    constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],

     uint3 gid [[threadgroup_position_in_grid]],
     ushort sidx [[simdgroup_index_in_threadgroup]],
     ushort lane_id [[thread_index_in_simdgroup]]
 ) {
     if (ideal_grouping) {
-        gemm_impl<
-            half,
-            half,
-            32,
-            32,
-            32,
-            1,
-            1
-        >(
-            A, B, C, threadgroup_block, gid, sidx, lane_id
+        gemm_impl<half, half, half, 32, 32, 32, 1, 1>(
+            A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
         );
     } else {
-        gemm_impl<
-            half,
-            half,
-            48,
-            48,
-            32,
-            1,
-            1
-        >(
-            A, B, C, threadgroup_block, gid, sidx, lane_id
+        gemm_impl<half, half, half, 48, 48, 32, 1, 1>(
+            A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
         );
     }
 }
@@ -473,40 +467,17 @@ kernel void sgemm(
     device float *C [[buffer(2)]],

     threadgroup uchar *threadgroup_block [[threadgroup(0)]],
+    constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],

     uint3 gid [[threadgroup_position_in_grid]],
     ushort sidx [[simdgroup_index_in_threadgroup]],
     ushort lane_id [[thread_index_in_simdgroup]]
 ) {
+    gemm_impl<float, float, float, 32, 32, 32, 2, 2>(
+        A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
+    );
+    /*
     if (prefer_async_copy) {
-        // TODO: figure out correct splits
-        if (ideal_grouping) {
-            gemm_impl<
-                float,
-                float,
-                32,
-                32,
-                32,
-                2,
-                2
-            >(
-                A, B, C, threadgroup_block, gid, sidx, lane_id
-            );
-        } else {
-            gemm_impl<
-                float,
-                float,
-                48,
-                48,
-                24,
-                2,
-                2
-            >(
-                A, B, C, threadgroup_block, gid, sidx, lane_id
-            );
-        }
-    } else {
-        // TODO: figure out correct splits
         constexpr ushort M_split = 1;
         constexpr ushort N_split = 1;
         if (ideal_grouping) {
@@ -534,5 +505,34 @@ kernel void sgemm(
                 A, B, C, threadgroup_block, gid, sidx, lane_id
             );
         }
+    } else {
+        constexpr ushort M_split = 2;
+        constexpr ushort N_split = 2;
+        if (ideal_grouping) {
+            gemm_impl<
+                float,
+                float,
+                32,
+                32,
+                8,
+                M_split,
+                N_split
+            >(
+                A, B, C, threadgroup_block, gid, sidx, lane_id
+            );
+        } else {
+            gemm_impl<
+                float,
+                float,
+                32,
+                32,
+                100,
+                M_split,
+                N_split
+            >(
+                A, B, C, threadgroup_block, gid, sidx, lane_id
+            );
+        }
     }
+    */
 }

@@ -1476,19 +1476,27 @@ pub fn call_gemm(
 ) -> Result<(), MetalKernelError> {
     let prefer_async_copy = !device.supports_family(MTLGPUFamily::Apple9);

-    let mut ideal_grouping = false;
     let mut actual_groups: usize = 1;
     actual_groups *= divide(m, 48) as usize;
     actual_groups *= divide(n, 48) as usize;
     actual_groups *= b;

     let core_count = get_device_core_count(device);
-    println!("Core count: {}", core_count);
     let ideal_grouping = if name == "sgemm" {
         actual_groups <= core_count * 6
     } else {
         actual_groups <= core_count * 9
     };

+    let mut blockdim = (32, 32, 32);
+    if !ideal_grouping {
+        if name == "sgemm" {
+            blockdim = (48, 48, 24);
+        } else {
+            blockdim = (48, 48, 32);
+        }
+    }
+
     assert!(rhs_stride.len() >= 2);
     assert!(lhs_stride.len() >= 2);
     let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
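Read as a standalone function, the grouping heuristic above compares the number of threadgroups a 48-wide launch would produce (across the whole batch) against a per-core budget: 6 groups per core for `sgemm`, 9 for `hgemm`. A sketch of that check, under the assumption that `divide` is ceiling division as its use for grid sizing suggests; this is illustrative and not part of the patch:

// Illustrative restatement of the `ideal_grouping` heuristic above.
fn divide(value: usize, block: usize) -> usize {
    // Assumed ceiling division, matching the helper used in call_gemm.
    (value + block - 1) / block
}

fn ideal_grouping(m: usize, n: usize, b: usize, core_count: usize, is_sgemm: bool) -> bool {
    // Threadgroup count if 48 x 48 blocks were used, times the batch size.
    let actual_groups = divide(m, 48) * divide(n, 48) * b;
    let budget = if is_sgemm { core_count * 6 } else { core_count * 9 };
    // true  -> keep the 32 x 32 block dimensions
    // false -> fall back to the larger 48 x 48 blocks
    actual_groups <= budget
}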
@@ -1525,52 +1533,45 @@ pub fn call_gemm(
     let alpha = 1.0f32;
     let beta = 0.0f32;
     let batched = b > 1;
+    println!("batched: {batched}");
     let fused_activation = false;
     let fused_bias = false;
-    let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
-        let m_simd = 8;
-        let n_simd = 8;
-        let k_simd = 64;
-        let m_splits = 1;
-        let n_splits = 1;
-        (m_simd, n_simd, k_simd, m_splits, n_splits)
-    } else {
-        let m_simd = 40;
-        let n_simd = 40;
-        let k_simd = 32;
-        let m_splits = 1;
-        let n_splits = 1;
-        (m_simd, n_simd, k_simd, m_splits, n_splits)
-    };
     let constants = Some(ConstantValues::new(vec![
         (0, Value::USize(m)),
         (1, Value::USize(n)),
         (2, Value::USize(k)),
         (10, Value::Bool(a_trans)),
         (11, Value::Bool(b_trans)),
-        (13, Value::Bool(d_trans)),
-        (20, Value::F32(alpha)),
-        (21, Value::F32(beta)),
+        //(13, Value::Bool(d_trans)),
+        //(20, Value::F32(alpha)),
+        //(21, Value::F32(beta)),
         (100, Value::Bool(batched)),
-        (101, Value::Bool(fused_activation)),
+        //(101, Value::Bool(fused_activation)),
         // Garbage
         (102, Value::Bool(false)),
         (103, Value::Bool(false)),
         (113, Value::Bool(false)),
         (50_000, Value::Bool(false)),
         // End garbage
-        (200, Value::U16(32)),
-        (201, Value::U16(32)),
-        (202, Value::U16(32)),
+        //(200, Value::U16(blockdim.0)),
+        //(201, Value::U16(blockdim.1)),
+        //(202, Value::U16(blockdim.2)),
         (206, Value::Bool(prefer_async_copy)),
         (207, Value::Bool(ideal_grouping)),
-        (210, Value::U16(m_splits)),
-        (211, Value::U16(n_splits)),
-        (50_001, Value::Bool(fused_bias)),
+        //(210, Value::U16(m_splits)),
+        //(211, Value::U16(n_splits)),
+        //(50_001, Value::Bool(fused_bias)),
     ]));
     let pipeline = kernels.load_pipeline_with_constants(device, Source::Candle, name, constants)?;
-    let m_group = m_simd * m_splits;
-    let n_group = n_simd * n_splits;
+    let m_group: u16 = 32;
+    let n_group: u16 = 32;
+    let m_splits: u16 = 2;
+    let n_splits: u16 = 2;
+    let k_simd: u16 = 32;
+    let m_simd = m_group / m_splits;
+    let n_simd = n_group / n_splits;

     let a_block_length = m_group * k_simd;
     let b_block_length = k_simd * n_group;
@@ -1580,6 +1581,7 @@ pub fn call_gemm(
         let c_block_length = m_group * n_group;
         block_elements = std::cmp::max(c_block_length, block_elements)
     }
+    /*
     if fused_bias {
         if d_trans {
             block_elements = std::cmp::max(block_elements, m_group);
@@ -1587,6 +1589,7 @@ pub fn call_gemm(
             block_elements = std::cmp::max(block_elements, n_group);
         }
     }
+    */
     let bytes = match name {
         "sgemm" => 4,
         "hgemm" => 2,
@@ -1600,7 +1603,7 @@ pub fn call_gemm(

     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
-    encoder.set_threadgroup_memory_length(0, block_bytes.into());
+    encoder.set_threadgroup_memory_length(0, block_bytes as NSUInteger);
     encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
     encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
     encoder.set_buffer(2, Some(output), 0);
@@ -1614,7 +1617,7 @@ pub fn call_gemm(
     // TODO byte_stride_d
     let byte_stride_d = 0;

-    let buffer: Vec<u64> = vec![
+    let buffer: [u64; 4] = [
         byte_stride_a as _,
         byte_stride_b as _,
         byte_stride_c as _,
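These per-matrix byte strides are what the kernel reads as `matrix_offsets` on buffer index 10 when `batched` is set. A hedged sketch of how such an array could be bound with the metal crate's `set_bytes`; the helper name and the fourth slot (assumed to carry the currently-zero `byte_stride_d`) are illustrative assumptions, not shown in this hunk:

use metal::{ComputeCommandEncoderRef, NSUInteger};

// Illustrative sketch: bind four per-matrix byte strides at buffer index 10 so the
// kernel can read them as
// `constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]]`.
fn bind_matrix_offsets(
    encoder: &ComputeCommandEncoderRef,
    byte_stride_a: u64,
    byte_stride_b: u64,
    byte_stride_c: u64,
    byte_stride_d: u64,
) {
    let offsets = [byte_stride_a, byte_stride_b, byte_stride_c, byte_stride_d];
    encoder.set_bytes(
        10,
        core::mem::size_of_val(&offsets) as NSUInteger,
        offsets.as_ptr() as *const std::ffi::c_void,
    );
}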
Binary file not shown.
@@ -1100,6 +1100,11 @@ fn gemm() {
     let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
     let rhs_stride = vec![n * k, n, 1];
     let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+    println!("lhs: {lhs:?}");
+    println!("lhs_stride: {lhs_stride:?}");
+    println!("rhs: {rhs:?}");
+    println!("rhs_stride: {rhs_stride:?}");
+
     let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
     assert_eq!(
         approx(results, 4),
@@ -1111,6 +1116,11 @@ fn gemm() {
     let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
     let rhs_stride = vec![n * k, n, 1];
     let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+    println!("lhs: {lhs:?}");
+    println!("lhs_stride: {lhs_stride:?}");
+    println!("rhs: {rhs:?}");
+    println!("rhs_stride: {rhs_stride:?}");
+
     let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
     assert_eq!(
         approx(results, 4),