NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed and whose reinterpret_cast template arguments had been stripped.
The cast targets below were restored from the parameter types of the
declarations in this same patch. The exact spelling of the casts on removed
("-") lines and the indentation of context lines are assumed; verify them
against the pre-image before applying.

diff --git a/candle-metal-kernels/src/gemm.metal b/candle-metal-kernels/src/gemm.metal
index c5c1475f..9b1d40aa 100644
--- a/candle-metal-kernels/src/gemm.metal
+++ b/candle-metal-kernels/src/gemm.metal
@@ -9,34 +9,72 @@
 #ifndef __METAL_SIMDGROUP_EVENT
 #define __METAL_SIMDGROUP_EVENT
 
-struct _simdgroup_event_t {};
+// Invoking the generation of LLVM bitcode for async copies.
+//
+//   %struct._simdgroup_event_t = type opaque
+//
+struct _simdgroup_event_t;
 
-thread _simdgroup_event_t* __metal_simdgroup_async_copy_1d(
-  ulong, ulong,
-  threadgroup void*, const device void*, ulong)
-  __asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
+// Invoking the generation of LLVM bitcode for 1D async copies (device -> threadgroup).
+thread _simdgroup_event_t*
+__metal_simdgroup_async_copy_1d(
+  ulong, ulong, threadgroup void *, const device void *, ulong)
+  __asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
 
-thread _simdgroup_event_t* __metal_simdgroup_async_copy_1d(
-  ulong, ulong,
-  device void*, const threadgroup void*, ulong)
-  __asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
+// Invoking the generation of LLVM bitcode for 1D async copies (threadgroup -> device).
+thread _simdgroup_event_t*
+__metal_simdgroup_async_copy_1d(
+  ulong, ulong, device void *, const threadgroup void *, ulong)
+  __asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
 
-thread _simdgroup_event_t* __metal_simdgroup_async_copy_2d(
-  ulong, ulong,
-  threadgroup void*, ulong, ulong, ulong2,
-  const device void*, ulong, ulong, ulong2,
-  long2, int)
-  __asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
+// Invoking the generation of LLVM bitcode for async copies.
+//
+//   ; Function Attrs: argmemonly convergent nounwind
+//   declare %struct._simdgroup_event_t*
+//     @air.simdgroup_async_copy_2d.p3i8.p1i8(
+//       i64, i64,
+//       i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>,
+//       i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>,
+//       <2 x i64>, i32)
+//     local_unnamed_addr #4
+//
+thread _simdgroup_event_t*
+__metal_simdgroup_async_copy_2d(
+  ulong, ulong,
+  threadgroup void *, ulong, ulong, ulong2,
+  const device void *, ulong, ulong, ulong2,
+  long2, int)
+  __asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
 
-thread _simdgroup_event_t* __metal_simdgroup_async_copy_2d(
-  ulong, ulong,
-  device void*, ulong, ulong, ulong2,
-  const threadgroup void*, ulong, ulong, ulong2,
-  long2, int)
-  __asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
+// Invoking the generation of LLVM bitcode for async copies.
+//
+//   ; Function Attrs: argmemonly convergent nounwind
+//   declare %struct._simdgroup_event_t*
+//     @air.simdgroup_async_copy_2d.p1i8.p3i8(
+//       i64, i64,
+//       i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>,
+//       i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>,
+//       <2 x i64>, i32)
+//     local_unnamed_addr #4
+//
+thread _simdgroup_event_t*
+__metal_simdgroup_async_copy_2d(
+  ulong, ulong,
+  device void *, ulong, ulong, ulong2,
+  const threadgroup void *, ulong, ulong, ulong2,
+  long2, int)
+  __asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
 
-void __metal_wait_simdgroup_events(int, const thread _simdgroup_event_t**)
-  __asm("air.wait_simdgroup_events");
+// Invoking the generation of LLVM bitcode for async copies.
+//
+//   ; Function Attrs: convergent nounwind
+//   declare void
+//     @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
+//     local_unnamed_addr #3
+//
+void __metal_wait_simdgroup_events(
+  int, thread _simdgroup_event_t**)
+  __asm("air.wait_simdgroup_events");
 
 #pragma METAL internals : enable
 namespace metal
@@ -55,14 +93,14 @@ namespace metal
       const device T *src,
       ulong n_elements
     ) thread {
-      event = *__metal_simdgroup_async_copy_1d(
+      event = __metal_simdgroup_async_copy_1d(
         // Description of the data type.
         sizeof(T),
         alignof(T),
 
         // Description of the arguments.
-        reinterpret_cast<threadgroup void*>(dst),
-        reinterpret_cast<const device void*>(src),
+        reinterpret_cast<threadgroup void *>(dst),
+        reinterpret_cast<const device void *>(src),
         n_elements);
     }
 
@@ -72,7 +110,7 @@ namespace metal
       const threadgroup T *src,
       ulong n_elements
     ) thread {
-      event = *__metal_simdgroup_async_copy_1d(
+      event = __metal_simdgroup_async_copy_1d(
         // Description of the data type.
         sizeof(T),
         alignof(T),
@@ -104,7 +142,7 @@ namespace metal
         src_tile_dimensions = src_tile_dimensions.yx;
         dst_tile_dimensions = dst_tile_dimensions.yx;
       }
-      event = *__metal_simdgroup_async_copy_2d(
+      event = __metal_simdgroup_async_copy_2d(
         // Description of the data type.
         sizeof(T),
         alignof(T),
@@ -145,7 +183,7 @@ namespace metal
         src_tile_dimensions = src_tile_dimensions.yx;
         dst_tile_dimensions = dst_tile_dimensions.yx;
       }
-      event = *__metal_simdgroup_async_copy_2d(
+      event = __metal_simdgroup_async_copy_2d(
         // Description of the data type.
         sizeof(T),
         alignof(T),
@@ -168,11 +206,16 @@ namespace metal
     }
 
     METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
-      __metal_wait_simdgroup_events(count, reinterpret_cast<const thread _simdgroup_event_t**>(events));
+      __metal_wait_simdgroup_events(
+        count, reinterpret_cast<thread _simdgroup_event_t**>(events));
     }
 
   private:
-    thread _simdgroup_event_t event;
+    // Invoking the generation of LLVM bitcode for async copies.
+    //
+    //   %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
+    //
+    thread _simdgroup_event_t* event;
   };
 } // namespace metal
 #pragma METAL internals : disable