diff --git a/candle-metal-kernels/src/gemm.metal b/candle-metal-kernels/src/gemm.metal
index c5908ca9..c5c1475f 100644
--- a/candle-metal-kernels/src/gemm.metal
+++ b/candle-metal-kernels/src/gemm.metal
@@ -12,7 +12,8 @@
 struct _simdgroup_event_t {};
 
 thread _simdgroup_event_t* __metal_simdgroup_async_copy_1d(
-  ulong, ulong, threadgroup void *, const device void *, ulong)
+  ulong, ulong,
+  threadgroup void*, const device void*, ulong)
   __asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
 
 thread _simdgroup_event_t* __metal_simdgroup_async_copy_1d(
@@ -53,15 +54,15 @@ namespace metal
       threadgroup T *dst,
       const device T *src,
       ulong n_elements
-    ) {
+    ) thread {
       event = *__metal_simdgroup_async_copy_1d(
         // Description of the data type.
         sizeof(T),
         alignof(T),
 
         // Description of the arguments.
-        reinterpret_cast<threadgroup void *>(dst),
-        reinterpret_cast<const device void *>(src),
+        reinterpret_cast<threadgroup void*>(dst),
+        reinterpret_cast<const device void*>(src),
         n_elements);
     }
 
@@ -70,7 +71,7 @@ namespace metal
       device T *dst,
       const threadgroup T *src,
       ulong n_elements
-    ) {
+    ) thread {
       event = *__metal_simdgroup_async_copy_1d(
         // Description of the data type.
         sizeof(T),
@@ -228,112 +229,87 @@ namespace metal
     }
 
     METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
-        if (transpose_matrix) {
-            return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
-        } else {
-            return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
-        }
+      if (transpose_matrix) {
+        return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
+      } else {
+        return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
+      }
     }
 
     METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-        if (transpose_matrix) {
-            return src + matrix_origin.x * elements_per_row + matrix_origin.y;
-        } else {
-            return src + matrix_origin.y * elements_per_row + matrix_origin.x;
-        }
+      if (transpose_matrix) {
+        return src + matrix_origin.x * elements_per_row + matrix_origin.y;
+      } else {
+        return src + matrix_origin.y * elements_per_row + matrix_origin.x;
+      }
     }
 
     // WARNING: All load and store functions assume the X dimension is divisible by 2.
-    template <typename U>
-    METAL_FUNC void load(const device U *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
-        uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
-        U memoryForm0 = src[address0];
-        U memoryForm1 = src[address1];
-        ((thread T*)thread_elements())[0] = T(memoryForm0);
-        ((thread T*)thread_elements())[1] = T(memoryForm1);
-      } else if (elements_per_row % 2 != 0) {
-        uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
-        uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
-        U memoryForm0 = src[address0];
-        U memoryForm1 = src[address1];
-        ((thread T*)thread_elements())[0] = T(memoryForm0);
-        ((thread T*)thread_elements())[1] = T(memoryForm1);
-      } else {
-        auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
-        vec<U, 2> memoryForm = *(const device vec<U, 2>*)(src + combinedAddress);
-        *(thread_elements()) = vec<T, 2>(memoryForm);
-      }
-    }
+    METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        *(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
+      } else {
+        *(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
+      }
+    }
 
-    template <typename U>
-    METAL_FUNC void load(const threadgroup U *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
-        ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
-        U memoryForm0 = src[address0];
-        U memoryForm1 = src[address1];
-        ((thread T*)thread_elements())[0] = T(memoryForm0);
-        ((thread T*)thread_elements())[1] = T(memoryForm1);
-      } else if (elements_per_row % 2 != 0) {
-        ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
-        ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
-        U memoryForm0 = src[address0];
-        U memoryForm1 = src[address1];
-        ((thread T*)thread_elements())[0] = T(memoryForm0);
-        ((thread T*)thread_elements())[1] = T(memoryForm1);
-      } else {
-        auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
-        vec<U, 2> memoryForm = *(const threadgroup vec<U, 2>*)(src + combinedAddress);
-        *(thread_elements()) = vec<T, 2>(memoryForm);
-      }
-    }
+    METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        *(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
+      } else {
+        *(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
+      }
+    }
 
-    template <typename U>
-    METAL_FUNC void store(device U *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
-        uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
-        T registerForm0 = ((thread T*)thread_elements())[0];
-        T registerForm1 = ((thread T*)thread_elements())[1];
-        dst[address0] = U(registerForm0);
-        dst[address1] = U(registerForm1);
-      } else if (elements_per_row % 2 != 0) {
-        uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
-        uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
-        T registerForm0 = ((thread T*)thread_elements())[0];
-        T registerForm1 = ((thread T*)thread_elements())[1];
-        dst[address0] = U(registerForm0);
-        dst[address1] = U(registerForm1);
-      } else {
-        auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
-        vec<T, 2> registerForm = *(thread_elements());
-        *(device vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
-      }
-    }
+    METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
+      } else {
+        thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
+      }
+    }
 
-    template <typename U>
-    METAL_FUNC void store(threadgroup U *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
-        ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
-        T registerForm0 = ((thread T*)thread_elements())[0];
-        T registerForm1 = ((thread T*)thread_elements())[1];
-        dst[address0] = U(registerForm0);
-        dst[address1] = U(registerForm1);
-      } else if (elements_per_row % 2 != 0) {
-        ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
-        ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
-        T registerForm0 = ((thread T*)thread_elements())[0];
-        T registerForm1 = ((thread T*)thread_elements())[1];
-        dst[address0] = U(registerForm0);
-        dst[address1] = U(registerForm1);
-      } else {
-        auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
-        vec<T, 2> registerForm = *(thread_elements());
-        *(threadgroup vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
-      }
-    }
+    METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
+      } else {
+        thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
+      }
+    }
+
+    METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
+        dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
+      } else {
+        *reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
+      }
+    }
+
+    METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
+      } else {
+        dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
+      }
+    }
+    METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
+      } else {
+        dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
+      }
+    }
+
+    METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
+        dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
+      } else {
+        *reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
+      }
+    }
 
     template <typename U, typename V>
@@ -343,6 +319,46 @@ namespace metal
       }
       t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
     }
+
+    // 'bfloat' is 'float' with the lower 16 bits set to garbage (BF16).
+
+    METAL_FUNC thread ushort4* thread_elements_bfloat() thread {
+      thread float2* elements = thread_elements();
+      return reinterpret_cast<thread ushort4*>(elements);
+    }
+
+    METAL_FUNC simdgroup_matrix_storage<float> unpack_bfloat() thread {
+      ushort4 output;
+      thread ushort2& elements = *(thread_elements());
+      output.y = elements[0];
+      output.w = elements[1];
+      return simdgroup_matrix_storage<float>(as_type<float2>(output));
+    }
+
+    METAL_FUNC simdgroup_matrix_storage<ushort> pack_bfloat() thread {
+      thread ushort4* elements = thread_elements_bfloat();
+      return simdgroup_matrix_storage<ushort>(ushort2(elements->y, elements->w));
+    }
+
+    METAL_FUNC void load_bfloat(const threadgroup ushort *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        thread_elements_bfloat()->y = src[matrix_origin.x * elements_per_row + matrix_origin.y];
+        thread_elements_bfloat()->w = src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y];
+      } else {
+        thread_elements_bfloat()->zw = *reinterpret_cast<const threadgroup ushort2*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
+        thread_elements_bfloat()->y = thread_elements_bfloat()->z;
+      }
+    }
+
+    METAL_FUNC void store_bfloat(threadgroup ushort *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements_bfloat()->y;
+        dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements_bfloat()->w;
+      } else {
+        thread_elements_bfloat()->z = thread_elements_bfloat()->y;
+        *reinterpret_cast<threadgroup vec<ushort, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = thread_elements_bfloat()->zw;
+      }
+    }
   };
 } // namespace metal
 #pragma METAL internals : disable
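Notes (illustrative, not part of the patch):

`apply_offset` keeps every operand in row-major order and handles transposition purely in index arithmetic: `transpose_matrix` swaps the roles of the two coordinates, so the same kernel can walk either A or its transpose without materializing anything. A minimal host-side C++ sketch of that index math (the function name and the test harness are assumptions, not from the patch):

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the kernel's apply_offset: row-major addressing where
// transpose_matrix swaps the roles of the x and y coordinates.
uint64_t linear_index(uint32_t x, uint32_t y, uint32_t elements_per_row,
                      bool transpose_matrix = false) {
  if (transpose_matrix) {
    return uint64_t(x) * elements_per_row + y;
  } else {
    return uint64_t(y) * elements_per_row + x;
  }
}

int main() {
  // 4 elements per row: element (x=1, y=2) lives at 2*4 + 1 = 9.
  assert(linear_index(1, 2, 4) == 9);
  // Transposed traversal of the same origin: 1*4 + 2 = 6.
  assert(linear_index(1, 2, 4, true) == 6);
  return 0;
}
```

One subtlety the sketch glosses over: the device-pointer variant in the patch multiplies in 32 bits and only then widens, as in `ulong(matrix_origin.x * elements_per_row)`, so the product itself must still fit in a `uint`.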
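The rewritten `load`/`store` pair drops the old `elements_per_row % 2 != 0` fallback: the non-transposed path now always moves a thread's two adjacent elements as a single `vec<T, 2>`, while the transposed path issues two scalar accesses strided by a full row. This is also why the `WARNING` about the X dimension being divisible by 2 stays load-bearing. A CPU-side C++ sketch of the two paths (the `ThreadElements` and `load` names are hypothetical, assuming a row-major `float` matrix):

```cpp
#include <cstddef>
#include <cstring>
#include <iostream>

// A thread's slice of an 8x8 simdgroup tile: two adjacent elements in X.
struct ThreadElements { float v[2]; };

void load(ThreadElements &te, const float *src, unsigned elements_per_row,
          unsigned ox, unsigned oy, bool transpose_matrix) {
  if (transpose_matrix) {
    // Pair split across two rows: two scalar reads, stride = row width.
    te.v[0] = src[std::size_t(ox) * elements_per_row + oy];
    te.v[1] = src[std::size_t(ox + 1) * elements_per_row + oy];
  } else {
    // Pair adjacent within one row: one two-element read, which is what
    // the kernel's vec<T, 2> reinterpret_cast amounts to.
    std::memcpy(te.v, src + std::size_t(oy) * elements_per_row + ox,
                sizeof(te.v));
  }
}

int main() {
  float m[4 * 4];
  for (int i = 0; i < 16; ++i) m[i] = float(i);
  ThreadElements te{};
  load(te, m, 4, 2, 1, false);
  std::cout << te.v[0] << ", " << te.v[1] << "\n"; // 6, 7
  load(te, m, 4, 2, 1, true);
  std::cout << te.v[0] << ", " << te.v[1] << "\n"; // 9, 13
  return 0;
}
```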
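The `*_bfloat` helpers exploit the fact that bfloat16 is exactly the upper 16 bits of an IEEE-754 float32: `load_bfloat`/`unpack_bfloat` place the stored `ushort` values in the high half of each `float` lane (the low half is left as garbage, per the comment), and `pack_bfloat`/`store_bfloat` keep only the high halves, i.e. they truncate rather than round. A self-contained C++ sketch of the same bit trick (not from the patch; it zeroes the low bits instead of leaving them undefined):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// bfloat16 <-> float32: a bf16 value is the top 16 bits of a float.
uint16_t pack_bfloat(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return uint16_t(bits >> 16); // truncate, as the kernel does
}

float unpack_bfloat(uint16_t bf) {
  uint32_t bits = uint32_t(bf) << 16; // low 16 bits zeroed here
  float x;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  float x = 3.14159f;
  float roundtrip = unpack_bfloat(pack_bfloat(x));
  // e.g. 3.14159 -> 3.14062 (bf16 truncation keeps ~7 bits of mantissa)
  std::cout << x << " -> " << roundtrip << "\n";
  return 0;
}
```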