From aefca7f8e6cb198c2c7f3479c3590f523bd55d09 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 2 Sep 2024 12:33:58 +0200 Subject: [PATCH] Revert "Testing ushort intermediate in case combo of async and f16/bf16 is the issue" This reverts commit c3b07579959cd0839ff5752e26fcdd37f465f723. --- candle-metal-kernels/src/gemm.metal | 283 ++-------------------------- 1 file changed, 20 insertions(+), 263 deletions(-) diff --git a/candle-metal-kernels/src/gemm.metal b/candle-metal-kernels/src/gemm.metal index 751fd5e3..c5908ca9 100644 --- a/candle-metal-kernels/src/gemm.metal +++ b/candle-metal-kernels/src/gemm.metal @@ -48,15 +48,16 @@ namespace metal struct simdgroup_event { METAL_FUNC simdgroup_event() thread {} + template METAL_FUNC void async_copy( - threadgroup float *dst, - const device float *src, + threadgroup T *dst, + const device T *src, ulong n_elements ) { event = *__metal_simdgroup_async_copy_1d( // Description of the data type. - sizeof(float), - alignof(float), + sizeof(T), + alignof(T), // Description of the arguments. reinterpret_cast(dst), @@ -64,51 +65,16 @@ namespace metal n_elements); } + template METAL_FUNC void async_copy( - threadgroup bfloat *dst, - const device bfloat *src, - ulong n_elements - ) { - threadgroup ushort *re_dst = reinterpret_cast(dst); - const device ushort *re_src = reinterpret_cast(src); - event = *__metal_simdgroup_async_copy_1d( - // Description of the data type. - sizeof(ushort), - alignof(ushort), - - // Description of the arguments. - reinterpret_cast(re_dst), - reinterpret_cast(re_src), - n_elements); - } - - METAL_FUNC void async_copy( - threadgroup half *dst, - const device half *src, - ulong n_elements - ) { - threadgroup ushort *re_dst = reinterpret_cast(dst); - const device ushort *re_src = reinterpret_cast(src); - event = *__metal_simdgroup_async_copy_1d( - // Description of the data type. - sizeof(ushort), - alignof(ushort), - - // Description of the arguments. - reinterpret_cast(re_dst), - reinterpret_cast(re_src), - n_elements); - } - - METAL_FUNC void async_copy( - device float *dst, - const threadgroup float *src, + device T *dst, + const threadgroup T *src, ulong n_elements ) { event = *__metal_simdgroup_async_copy_1d( // Description of the data type. - sizeof(float), - alignof(float), + sizeof(T), + alignof(T), // Description of the arguments. reinterpret_cast(dst), @@ -116,50 +82,15 @@ namespace metal n_elements); } - METAL_FUNC void async_copy( - device bfloat *dst, - const threadgroup bfloat *src, - ulong n_elements - ) { - device ushort *re_dst = reinterpret_cast(dst); - const threadgroup ushort *re_src = reinterpret_cast(src); - event = *__metal_simdgroup_async_copy_1d( - // Description of the data type. - sizeof(ushort), - alignof(ushort), - - // Description of the arguments. - reinterpret_cast(re_dst), - reinterpret_cast(re_src), - n_elements); - } - - METAL_FUNC void async_copy( - device half *dst, - const threadgroup half *src, - ulong n_elements - ) { - device ushort *re_dst = reinterpret_cast(dst); - const threadgroup ushort *re_src = reinterpret_cast(src); - event = *__metal_simdgroup_async_copy_1d( - // Description of the data type. - sizeof(ushort), - alignof(ushort), - - // Description of the arguments. - reinterpret_cast(re_dst), - reinterpret_cast(re_src), - n_elements); - } - + template METAL_FUNC void async_copy( // Description of the destination. - threadgroup float *dst, + threadgroup T *dst, ushort dst_elements_per_row, ushort2 dst_tile_dimensions, // Description of the source. - const device float *src, + const device T *src, uint src_elements_per_row, ushort2 src_tile_dimensions, @@ -174,8 +105,8 @@ namespace metal } event = *__metal_simdgroup_async_copy_2d( // Description of the data type. - sizeof(float), - alignof(float), + sizeof(T), + alignof(T), // Description of the destination. reinterpret_cast(dst), @@ -194,104 +125,15 @@ namespace metal static_cast(clamp_mode)); } - METAL_FUNC void async_copy( - // Description of the destination. - threadgroup bfloat *dst, - ushort dst_elements_per_row, - ushort2 dst_tile_dimensions, - - // Description of the source. - const device bfloat *src, - uint src_elements_per_row, - ushort2 src_tile_dimensions, - - // Other arguments. - bool transpose_matrix = false, - simdgroup_async_copy_clamp_mode clamp_mode = - simdgroup_async_copy_clamp_mode::clamp_to_zero - ) thread { - if (transpose_matrix) { - src_tile_dimensions = src_tile_dimensions.yx; - dst_tile_dimensions = dst_tile_dimensions.yx; - } - - threadgroup ushort *re_dst = reinterpret_cast(dst); - const device ushort *re_src = reinterpret_cast(src); - event = *__metal_simdgroup_async_copy_2d( - // Description of the data type. - sizeof(ushort), - alignof(ushort), - - // Description of the destination. - reinterpret_cast(re_dst), - ushort(dst_elements_per_row), - 1, - ulong2(dst_tile_dimensions), - - // Description of the source. - reinterpret_cast(re_src), - uint(src_elements_per_row), - 1, - ulong2(src_tile_dimensions), - - // Other arguments. - long2(0), - static_cast(clamp_mode)); - } - + template METAL_FUNC void async_copy( // Description of the destination. - threadgroup half *dst, - ushort dst_elements_per_row, - ushort2 dst_tile_dimensions, - - // Description of the source. - const device half *src, - uint src_elements_per_row, - ushort2 src_tile_dimensions, - - // Other arguments. - bool transpose_matrix = false, - simdgroup_async_copy_clamp_mode clamp_mode = - simdgroup_async_copy_clamp_mode::clamp_to_zero - ) thread { - if (transpose_matrix) { - src_tile_dimensions = src_tile_dimensions.yx; - dst_tile_dimensions = dst_tile_dimensions.yx; - } - - threadgroup ushort *re_dst = reinterpret_cast(dst); - const device ushort *re_src = reinterpret_cast(src); - event = *__metal_simdgroup_async_copy_2d( - // Description of the data type. - sizeof(ushort), - alignof(ushort), - - // Description of the destination. - reinterpret_cast(re_dst), - ushort(dst_elements_per_row), - 1, - ulong2(dst_tile_dimensions), - - // Description of the source. - reinterpret_cast(re_src), - uint(src_elements_per_row), - 1, - ulong2(src_tile_dimensions), - - // Other arguments. - long2(0), - static_cast(clamp_mode)); - } - - METAL_FUNC void async_copy( - // Description of the destination. - device float *dst, + device T *dst, uint dst_elements_per_row, ushort2 dst_tile_dimensions, // Description of the source. - const threadgroup float *src, + const threadgroup T *src, ushort src_elements_per_row, ushort2 src_tile_dimensions, @@ -304,8 +146,8 @@ namespace metal } event = *__metal_simdgroup_async_copy_2d( // Description of the data type. - sizeof(float), - alignof(float), + sizeof(T), + alignof(T), // Description of the destination. reinterpret_cast(dst), @@ -324,90 +166,6 @@ namespace metal 0); } - METAL_FUNC void async_copy( - // Description of the destination. - device bfloat *dst, - uint dst_elements_per_row, - ushort2 dst_tile_dimensions, - - // Description of the source. - const threadgroup bfloat *src, - ushort src_elements_per_row, - ushort2 src_tile_dimensions, - - // Other arguments. - bool transpose_matrix = false - ) thread { - if (transpose_matrix) { - src_tile_dimensions = src_tile_dimensions.yx; - dst_tile_dimensions = dst_tile_dimensions.yx; - } - device ushort *re_dst = reinterpret_cast(dst); - const threadgroup ushort *re_src = reinterpret_cast(src); - event = *__metal_simdgroup_async_copy_2d( - // Description of the data type. - sizeof(ushort), - alignof(ushort), - - // Description of the destination. - reinterpret_cast(re_dst), - uint(dst_elements_per_row), - 1, - ulong2(dst_tile_dimensions), - - // Description of the source. - reinterpret_cast(re_src), - ushort(src_elements_per_row), - 1, - ulong2(src_tile_dimensions), - - // Other arguments. - long2(0), - 0); - } - - METAL_FUNC void async_copy( - // Description of the destination. - device half *dst, - uint dst_elements_per_row, - ushort2 dst_tile_dimensions, - - // Description of the source. - const threadgroup half *src, - ushort src_elements_per_row, - ushort2 src_tile_dimensions, - - // Other arguments. - bool transpose_matrix = false - ) thread { - if (transpose_matrix) { - src_tile_dimensions = src_tile_dimensions.yx; - dst_tile_dimensions = dst_tile_dimensions.yx; - } - device ushort *re_dst = reinterpret_cast(dst); - const threadgroup ushort *re_src = reinterpret_cast(src); - event = *__metal_simdgroup_async_copy_2d( - // Description of the data type. - sizeof(ushort), - alignof(ushort), - - // Description of the destination. - reinterpret_cast(re_dst), - uint(dst_elements_per_row), - 1, - ulong2(dst_tile_dimensions), - - // Description of the source. - reinterpret_cast(re_src), - ushort(src_elements_per_row), - 1, - ulong2(src_tile_dimensions), - - // Other arguments. - long2(0), - 0); - } - METAL_FUNC static void wait(int count, thread simdgroup_event *events) { __metal_wait_simdgroup_events(count, reinterpret_cast(events)); } @@ -955,7 +713,6 @@ void _gemm_impl(device T *A [[buffer(0)]], A_block_src += A_trans ? 8 * A_block_leading_dim : 8; B_block_src += B_trans ? 8 : 8 * B_block_leading_dim; } - threadgroup_barrier(mem_flags::mem_threadgroup); if (K_floor + K_simd < K) { #pragma clang loop unroll(full)