Revert "Testing ushort intermediate in case combo of async and f16/bf16 is the issue"

This reverts commit c3b0757995.
This commit is contained in:
Ivar Flakstad
2024-09-02 12:33:58 +02:00
parent c3b0757995
commit aefca7f8e6

View File

@ -48,15 +48,16 @@ namespace metal
struct simdgroup_event {
// Default constructor. The body is empty and there is no initializer list,
// so nothing is set up here; `event` is only assigned by the async_copy
// overloads.
METAL_FUNC simdgroup_event() thread
{
}
template <typename T>
METAL_FUNC void async_copy(
threadgroup float *dst,
const device float *src,
threadgroup T *dst,
const device T *src,
ulong n_elements
) {
event = *__metal_simdgroup_async_copy_1d(
// Description of the data type.
sizeof(float),
alignof(float),
sizeof(T),
alignof(T),
// Description of the arguments.
reinterpret_cast<threadgroup void *>(dst),
@ -64,51 +65,16 @@ namespace metal
n_elements);
}
template <typename T>
// 1D copy of bfloat elements from `device` memory into `threadgroup` memory,
// issued through __metal_simdgroup_async_copy_1d.
//
// Both pointers are first reinterpreted as ushort pointers, and the element
// size/alignment handed to the intrinsic are those of ushort, so the copy is
// described in terms of ushort elements rather than bfloat. This relies on
// bfloat and ushort sharing the same size and alignment.
//
// dst        - destination buffer (threadgroup address space).
// src        - source buffer (device address space).
// n_elements - forwarded unchanged to the intrinsic.
//
// The value obtained by dereferencing the intrinsic's result is stored in
// `event`.
METAL_FUNC void async_copy(
  threadgroup bfloat *dst,
  const device bfloat *src,
  ulong n_elements
) {
  // ushort views of the two buffers.
  const device ushort *src_bits =
    reinterpret_cast<const device ushort *>(src);
  threadgroup ushort *dst_bits =
    reinterpret_cast<threadgroup ushort *>(dst);

  event = *__metal_simdgroup_async_copy_1d(
    // Element type: size and alignment.
    sizeof(ushort), alignof(ushort),
    // Destination, source, count.
    reinterpret_cast<threadgroup void *>(dst_bits),
    reinterpret_cast<const device void *>(src_bits),
    n_elements);
}
// 1D copy of half elements from `device` memory into `threadgroup` memory,
// issued through __metal_simdgroup_async_copy_1d.
//
// The half pointers are reinterpreted as ushort pointers and the intrinsic is
// given sizeof/alignof of ushort, so the data is described as ushort elements.
// This relies on half and ushort sharing the same size and alignment.
//
// dst        - destination buffer (threadgroup address space).
// src        - source buffer (device address space).
// n_elements - forwarded unchanged to the intrinsic.
//
// The value obtained by dereferencing the intrinsic's result is stored in
// `event`.
METAL_FUNC void async_copy(
  threadgroup half *dst,
  const device half *src,
  ulong n_elements
) {
  // ushort views of the two buffers.
  const device ushort *src_bits =
    reinterpret_cast<const device ushort *>(src);
  threadgroup ushort *dst_bits =
    reinterpret_cast<threadgroup ushort *>(dst);

  event = *__metal_simdgroup_async_copy_1d(
    // Element type: size and alignment.
    sizeof(ushort), alignof(ushort),
    // Destination, source, count.
    reinterpret_cast<threadgroup void *>(dst_bits),
    reinterpret_cast<const device void *>(src_bits),
    n_elements);
}
METAL_FUNC void async_copy(
device float *dst,
const threadgroup float *src,
device T *dst,
const threadgroup T *src,
ulong n_elements
) {
event = *__metal_simdgroup_async_copy_1d(
// Description of the data type.
sizeof(float),
alignof(float),
sizeof(T),
alignof(T),
// Description of the arguments.
reinterpret_cast<device void *>(dst),
@ -116,50 +82,15 @@ namespace metal
n_elements);
}
// 1D copy of bfloat elements from `threadgroup` memory out to `device` memory,
// issued through __metal_simdgroup_async_copy_1d.
//
// Both pointers are reinterpreted as ushort pointers and the intrinsic is
// given sizeof/alignof of ushort, so the data is described as ushort elements.
// This relies on bfloat and ushort sharing the same size and alignment.
//
// dst        - destination buffer (device address space).
// src        - source buffer (threadgroup address space).
// n_elements - forwarded unchanged to the intrinsic.
//
// The value obtained by dereferencing the intrinsic's result is stored in
// `event`.
METAL_FUNC void async_copy(
  device bfloat *dst,
  const threadgroup bfloat *src,
  ulong n_elements
) {
  // ushort views of the two buffers.
  const threadgroup ushort *src_bits =
    reinterpret_cast<const threadgroup ushort *>(src);
  device ushort *dst_bits =
    reinterpret_cast<device ushort *>(dst);

  event = *__metal_simdgroup_async_copy_1d(
    // Element type: size and alignment.
    sizeof(ushort), alignof(ushort),
    // Destination, source, count.
    reinterpret_cast<device void *>(dst_bits),
    reinterpret_cast<const threadgroup void *>(src_bits),
    n_elements);
}
// 1D copy of half elements from `threadgroup` memory out to `device` memory,
// issued through __metal_simdgroup_async_copy_1d.
//
// Both pointers are reinterpreted as ushort pointers and the intrinsic is
// given sizeof/alignof of ushort, so the data is described as ushort elements.
// This relies on half and ushort sharing the same size and alignment.
//
// dst        - destination buffer (device address space).
// src        - source buffer (threadgroup address space).
// n_elements - forwarded unchanged to the intrinsic.
//
// The value obtained by dereferencing the intrinsic's result is stored in
// `event`.
METAL_FUNC void async_copy(
  device half *dst,
  const threadgroup half *src,
  ulong n_elements
) {
  // ushort views of the two buffers.
  const threadgroup ushort *src_bits =
    reinterpret_cast<const threadgroup ushort *>(src);
  device ushort *dst_bits =
    reinterpret_cast<device ushort *>(dst);

  event = *__metal_simdgroup_async_copy_1d(
    // Element type: size and alignment.
    sizeof(ushort), alignof(ushort),
    // Destination, source, count.
    reinterpret_cast<device void *>(dst_bits),
    reinterpret_cast<const threadgroup void *>(src_bits),
    n_elements);
}
template <typename T>
METAL_FUNC void async_copy(
// Description of the destination.
threadgroup float *dst,
threadgroup T *dst,
ushort dst_elements_per_row,
ushort2 dst_tile_dimensions,
// Description of the source.
const device float *src,
const device T *src,
uint src_elements_per_row,
ushort2 src_tile_dimensions,
@ -174,8 +105,8 @@ namespace metal
}
event = *__metal_simdgroup_async_copy_2d(
// Description of the data type.
sizeof(float),
alignof(float),
sizeof(T),
alignof(T),
// Description of the destination.
reinterpret_cast<threadgroup void *>(dst),
@ -194,104 +125,15 @@ namespace metal
static_cast<int>(clamp_mode));
}
// 2D (tiled) copy of bfloat elements from `device` memory into `threadgroup`
// memory, issued through __metal_simdgroup_async_copy_2d.
//
// The bfloat pointers are reinterpreted as ushort pointers and the intrinsic
// is given sizeof/alignof of ushort, so the tile is described as ushort
// elements. This relies on bfloat and ushort sharing the same size and
// alignment.
//
// dst / dst_elements_per_row / dst_tile_dimensions - destination description.
// src / src_elements_per_row / src_tile_dimensions - source description.
// transpose_matrix - when true, the x and y components of BOTH tile-dimension
//                    vectors are exchanged before the call; nothing else in
//                    this function depends on it.
// clamp_mode       - passed to the intrinsic as an int; defaults to
//                    clamp_to_zero.
//
// The value obtained by dereferencing the intrinsic's result is stored in
// `event`.
METAL_FUNC void async_copy(
  // Destination.
  threadgroup bfloat *dst,
  ushort dst_elements_per_row,
  ushort2 dst_tile_dimensions,

  // Source.
  const device bfloat *src,
  uint src_elements_per_row,
  ushort2 src_tile_dimensions,

  // Options.
  bool transpose_matrix = false,
  simdgroup_async_copy_clamp_mode clamp_mode =
    simdgroup_async_copy_clamp_mode::clamp_to_zero
) thread {
  if (transpose_matrix) {
    // Exchange x and y on the by-value copies of the tile extents.
    dst_tile_dimensions = ushort2(dst_tile_dimensions.y, dst_tile_dimensions.x);
    src_tile_dimensions = ushort2(src_tile_dimensions.y, src_tile_dimensions.x);
  }

  // ushort views of the two buffers.
  const device ushort *src_bits =
    reinterpret_cast<const device ushort *>(src);
  threadgroup ushort *dst_bits =
    reinterpret_cast<threadgroup ushort *>(dst);

  // NOTE(review): the literal `1` after each row length and the trailing
  // long2(0) are passed as-is; their meaning is defined by the intrinsic's
  // declaration, which is not visible here (presumably a stride and an
  // offset -- confirm).
  event = *__metal_simdgroup_async_copy_2d(
    // Element type: size and alignment.
    sizeof(ushort), alignof(ushort),
    // Destination.
    reinterpret_cast<threadgroup void *>(dst_bits),
    ushort(dst_elements_per_row),
    1,
    ulong2(dst_tile_dimensions),
    // Source.
    reinterpret_cast<const device void *>(src_bits),
    uint(src_elements_per_row),
    1,
    ulong2(src_tile_dimensions),
    // Remaining arguments.
    long2(0),
    static_cast<int>(clamp_mode));
}
template <typename T>
// 2D (tiled) copy of half elements from `device` memory into `threadgroup`
// memory, issued through __metal_simdgroup_async_copy_2d.
//
// The half pointers are reinterpreted as ushort pointers and the intrinsic is
// given sizeof/alignof of ushort, so the tile is described as ushort
// elements. This relies on half and ushort sharing the same size and
// alignment.
//
// dst / dst_elements_per_row / dst_tile_dimensions - destination description.
// src / src_elements_per_row / src_tile_dimensions - source description.
// transpose_matrix - when true, the x and y components of BOTH tile-dimension
//                    vectors are exchanged before the call; nothing else in
//                    this function depends on it.
// clamp_mode       - passed to the intrinsic as an int; defaults to
//                    clamp_to_zero.
//
// The value obtained by dereferencing the intrinsic's result is stored in
// `event`.
METAL_FUNC void async_copy(
  // Destination.
  threadgroup half *dst,
  ushort dst_elements_per_row,
  ushort2 dst_tile_dimensions,

  // Source.
  const device half *src,
  uint src_elements_per_row,
  ushort2 src_tile_dimensions,

  // Options.
  bool transpose_matrix = false,
  simdgroup_async_copy_clamp_mode clamp_mode =
    simdgroup_async_copy_clamp_mode::clamp_to_zero
) thread {
  if (transpose_matrix) {
    // Exchange x and y on the by-value copies of the tile extents.
    dst_tile_dimensions = ushort2(dst_tile_dimensions.y, dst_tile_dimensions.x);
    src_tile_dimensions = ushort2(src_tile_dimensions.y, src_tile_dimensions.x);
  }

  // ushort views of the two buffers.
  const device ushort *src_bits =
    reinterpret_cast<const device ushort *>(src);
  threadgroup ushort *dst_bits =
    reinterpret_cast<threadgroup ushort *>(dst);

  // NOTE(review): the literal `1` after each row length and the trailing
  // long2(0) are passed as-is; their meaning is defined by the intrinsic's
  // declaration, which is not visible here (presumably a stride and an
  // offset -- confirm).
  event = *__metal_simdgroup_async_copy_2d(
    // Element type: size and alignment.
    sizeof(ushort), alignof(ushort),
    // Destination.
    reinterpret_cast<threadgroup void *>(dst_bits),
    ushort(dst_elements_per_row),
    1,
    ulong2(dst_tile_dimensions),
    // Source.
    reinterpret_cast<const device void *>(src_bits),
    uint(src_elements_per_row),
    1,
    ulong2(src_tile_dimensions),
    // Remaining arguments.
    long2(0),
    static_cast<int>(clamp_mode));
}
METAL_FUNC void async_copy(
// Description of the destination.
device float *dst,
device T *dst,
uint dst_elements_per_row,
ushort2 dst_tile_dimensions,
// Description of the source.
const threadgroup float *src,
const threadgroup T *src,
ushort src_elements_per_row,
ushort2 src_tile_dimensions,
@ -304,8 +146,8 @@ namespace metal
}
event = *__metal_simdgroup_async_copy_2d(
// Description of the data type.
sizeof(float),
alignof(float),
sizeof(T),
alignof(T),
// Description of the destination.
reinterpret_cast<device void *>(dst),
@ -324,90 +166,6 @@ namespace metal
0);
}
// 2D (tiled) copy of bfloat elements from `threadgroup` memory out to
// `device` memory, issued through __metal_simdgroup_async_copy_2d.
//
// The bfloat pointers are reinterpreted as ushort pointers and the intrinsic
// is given sizeof/alignof of ushort, so the tile is described as ushort
// elements. This relies on bfloat and ushort sharing the same size and
// alignment.
//
// dst / dst_elements_per_row / dst_tile_dimensions - destination description.
// src / src_elements_per_row / src_tile_dimensions - source description.
// transpose_matrix - when true, the x and y components of BOTH tile-dimension
//                    vectors are exchanged before the call; nothing else in
//                    this function depends on it.
//
// Unlike the device -> threadgroup overloads there is no clamp_mode
// parameter; the final intrinsic argument is the constant 0.
//
// The value obtained by dereferencing the intrinsic's result is stored in
// `event`.
METAL_FUNC void async_copy(
  // Destination.
  device bfloat *dst,
  uint dst_elements_per_row,
  ushort2 dst_tile_dimensions,

  // Source.
  const threadgroup bfloat *src,
  ushort src_elements_per_row,
  ushort2 src_tile_dimensions,

  // Options.
  bool transpose_matrix = false
) thread {
  if (transpose_matrix) {
    // Exchange x and y on the by-value copies of the tile extents.
    dst_tile_dimensions = ushort2(dst_tile_dimensions.y, dst_tile_dimensions.x);
    src_tile_dimensions = ushort2(src_tile_dimensions.y, src_tile_dimensions.x);
  }

  // ushort views of the two buffers.
  const threadgroup ushort *src_bits =
    reinterpret_cast<const threadgroup ushort *>(src);
  device ushort *dst_bits =
    reinterpret_cast<device ushort *>(dst);

  // NOTE(review): the literal `1` after each row length and the trailing
  // long2(0) are passed as-is; their meaning is defined by the intrinsic's
  // declaration, which is not visible here (presumably a stride and an
  // offset -- confirm).
  event = *__metal_simdgroup_async_copy_2d(
    // Element type: size and alignment.
    sizeof(ushort), alignof(ushort),
    // Destination.
    reinterpret_cast<device void *>(dst_bits),
    uint(dst_elements_per_row),
    1,
    ulong2(dst_tile_dimensions),
    // Source.
    reinterpret_cast<const threadgroup void *>(src_bits),
    ushort(src_elements_per_row),
    1,
    ulong2(src_tile_dimensions),
    // Remaining arguments.
    long2(0),
    0);
}
// 2D (tiled) copy of half elements from `threadgroup` memory out to `device`
// memory, issued through __metal_simdgroup_async_copy_2d.
//
// The half pointers are reinterpreted as ushort pointers and the intrinsic is
// given sizeof/alignof of ushort, so the tile is described as ushort
// elements. This relies on half and ushort sharing the same size and
// alignment.
//
// dst / dst_elements_per_row / dst_tile_dimensions - destination description.
// src / src_elements_per_row / src_tile_dimensions - source description.
// transpose_matrix - when true, the x and y components of BOTH tile-dimension
//                    vectors are exchanged before the call; nothing else in
//                    this function depends on it.
//
// Unlike the device -> threadgroup overloads there is no clamp_mode
// parameter; the final intrinsic argument is the constant 0.
//
// The value obtained by dereferencing the intrinsic's result is stored in
// `event`.
METAL_FUNC void async_copy(
  // Destination.
  device half *dst,
  uint dst_elements_per_row,
  ushort2 dst_tile_dimensions,

  // Source.
  const threadgroup half *src,
  ushort src_elements_per_row,
  ushort2 src_tile_dimensions,

  // Options.
  bool transpose_matrix = false
) thread {
  if (transpose_matrix) {
    // Exchange x and y on the by-value copies of the tile extents.
    dst_tile_dimensions = ushort2(dst_tile_dimensions.y, dst_tile_dimensions.x);
    src_tile_dimensions = ushort2(src_tile_dimensions.y, src_tile_dimensions.x);
  }

  // ushort views of the two buffers.
  const threadgroup ushort *src_bits =
    reinterpret_cast<const threadgroup ushort *>(src);
  device ushort *dst_bits =
    reinterpret_cast<device ushort *>(dst);

  // NOTE(review): the literal `1` after each row length and the trailing
  // long2(0) are passed as-is; their meaning is defined by the intrinsic's
  // declaration, which is not visible here (presumably a stride and an
  // offset -- confirm).
  event = *__metal_simdgroup_async_copy_2d(
    // Element type: size and alignment.
    sizeof(ushort), alignof(ushort),
    // Destination.
    reinterpret_cast<device void *>(dst_bits),
    uint(dst_elements_per_row),
    1,
    ulong2(dst_tile_dimensions),
    // Source.
    reinterpret_cast<const threadgroup void *>(src_bits),
    ushort(src_elements_per_row),
    1,
    ulong2(src_tile_dimensions),
    // Remaining arguments.
    long2(0),
    0);
}
// Hands `count` events to __metal_wait_simdgroup_events.
//
// count  - forwarded unchanged as the intrinsic's first argument.
// events - pointer to simdgroup_event object(s) in the `thread` address
//          space; reinterpreted as a pointer to raw _simdgroup_event_t
//          handles before being passed on.
//
// NOTE(review): the reinterpret_cast presumably depends on simdgroup_event
// being layout-compatible with a single event handle pointer; the struct's
// members are not visible here -- confirm before adding fields.
METAL_FUNC static void wait(int count, thread simdgroup_event *events)
{
  __metal_wait_simdgroup_events(
    count,
    reinterpret_cast<const thread _simdgroup_event_t**>(events));
}
@ -955,7 +713,6 @@ void _gemm_impl(device T *A [[buffer(0)]],
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (K_floor + K_simd < K) {
#pragma clang loop unroll(full)