Revert "Alter metal simdgroup matrix load/store ops"

This reverts commit 560f666d29.
This commit is contained in:
Ivar Flakstad
2024-09-02 12:34:09 +02:00
parent aefca7f8e6
commit f9b2bb4d46

View File

@ -12,7 +12,8 @@
// Empty tag type. In this file it only appears behind the pointer returned by
// the __metal_simdgroup_async_copy_1d declarations.
struct _simdgroup_event_t {};
// Declaration bound, via the __asm label, to the symbol
// "air.simdgroup_async_copy_1d.p3i8.p1i8". It takes (ulong, ulong,
// threadgroup void*, const device void*, ulong) and returns a pointer to a
// _simdgroup_event_t in the thread address space. The callers visible in this
// file pass sizeof(T), alignof(T), the destination, the source, and the
// element count, in that order.
thread _simdgroup_event_t* __metal_simdgroup_async_copy_1d(
ulong, ulong, threadgroup void *, const device void *, ulong)
ulong, ulong,
threadgroup void*, const device void*, ulong)
__asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
thread _simdgroup_event_t* __metal_simdgroup_async_copy_1d(
@ -53,15 +54,15 @@ namespace metal
threadgroup T *dst,
const device T *src,
ulong n_elements
) {
) thread {
event = *__metal_simdgroup_async_copy_1d(
// Description of the data type.
sizeof(T),
alignof(T),
// Description of the arguments.
reinterpret_cast<threadgroup void *>(dst),
reinterpret_cast<const device void *>(src),
reinterpret_cast<threadgroup void*>(dst),
reinterpret_cast<const device void*>(src),
n_elements);
}
@ -70,7 +71,7 @@ namespace metal
device T *dst,
const threadgroup T *src,
ulong n_elements
) {
) thread {
event = *__metal_simdgroup_async_copy_1d(
// Description of the data type.
sizeof(T),
@ -228,112 +229,87 @@ namespace metal
}
// Returns `src` advanced by the element offset of `matrix_origin`.
//
// The offset is `matrix_origin.y * elements_per_row + matrix_origin.x`, or
// `matrix_origin.x * elements_per_row + matrix_origin.y` when
// `transpose_matrix` is true.
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
  // Widen BEFORE multiplying: a uint * uint product wraps at 2^32, and casting
  // the already-wrapped product to ulong cannot recover the lost bits.
  const ulong stride = ulong(elements_per_row);
  if (transpose_matrix) {
    return src + ulong(matrix_origin.x) * stride + matrix_origin.y;
  } else {
    return src + ulong(matrix_origin.y) * stride + matrix_origin.x;
  }
}
// Returns `src` advanced by the element offset of `matrix_origin` in
// threadgroup memory.
//
// The offset is `matrix_origin.y * elements_per_row + matrix_origin.x`, or
// `matrix_origin.x * elements_per_row + matrix_origin.y` when
// `transpose_matrix` is true.
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  // Both branches return, so a single if/else is the whole body; the second
  // copy of the same if/else that used to follow could never execute.
  if (transpose_matrix) {
    return src + matrix_origin.x * elements_per_row + matrix_origin.y;
  }
  return src + matrix_origin.y * elements_per_row + matrix_origin.x;
}
// WARNING: All load and store functions assume the X dimension is divisible by 2.
template <typename U>
METAL_FUNC void load(const device U *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
U memoryForm0 = src[address0];
U memoryForm1 = src[address1];
((thread T*)thread_elements())[0] = T(memoryForm0);
((thread T*)thread_elements())[1] = T(memoryForm1);
} else if (elements_per_row % 2 != 0) {
uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
U memoryForm0 = src[address0];
U memoryForm1 = src[address1];
((thread T*)thread_elements())[0] = T(memoryForm0);
((thread T*)thread_elements())[1] = T(memoryForm1);
} else {
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
vec<U, 2> memoryForm = *(const device vec<U, 2>*)(src + combinedAddress);
*(thread_elements()) = vec<T, 2>(memoryForm);
}
// Fills the vec<T, 2> that thread_elements() points to from device memory.
//
// Non-transposed: reads two adjacent elements starting at
// `matrix_origin.y * elements_per_row + matrix_origin.x` through a single
// vec<T, 2> access.
// Transposed: reads the element at
// `matrix_origin.x * elements_per_row + matrix_origin.y` and the element one
// `elements_per_row` stride after it.
//
// NOTE(review): the vec<T, 2> access presumably needs the resulting address to
// be suitably aligned for vec<T, 2> -- confirm against callers.
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  // Widen BEFORE multiplying: the product is otherwise formed in 32 bits and
  // wraps before the ulong cast is applied.
  const ulong stride = ulong(elements_per_row);
  if (transpose_matrix) {
    const ulong first = ulong(matrix_origin.x) * stride + matrix_origin.y;
    // (x + 1) * stride + y == first + stride.
    *(thread_elements()) = vec<T, 2>(src[first], src[first + stride]);
  } else {
    const ulong offset = ulong(matrix_origin.y) * stride + matrix_origin.x;
    *(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + offset);
  }
}
template <typename U>
METAL_FUNC void load(const threadgroup U *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
U memoryForm0 = src[address0];
U memoryForm1 = src[address1];
((thread T*)thread_elements())[0] = T(memoryForm0);
((thread T*)thread_elements())[1] = T(memoryForm1);
} else if (elements_per_row % 2 != 0) {
ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
U memoryForm0 = src[address0];
U memoryForm1 = src[address1];
((thread T*)thread_elements())[0] = T(memoryForm0);
((thread T*)thread_elements())[1] = T(memoryForm1);
} else {
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
vec<U, 2> memoryForm = *(const threadgroup vec<U, 2>*)(src + combinedAddress);
*(thread_elements()) = vec<T, 2>(memoryForm);
}
// Fills the vec<T, 2> that thread_elements() points to from threadgroup
// memory.
//
// Non-transposed: one vec<T, 2> read starting at
// `matrix_origin.y * elements_per_row + matrix_origin.x`.
// Transposed: two scalar reads, at row index `matrix_origin.x` and
// `matrix_origin.x + 1`, both at column index `matrix_origin.y`.
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  if (!transpose_matrix) {
    const threadgroup T *pair_start = src + matrix_origin.y * elements_per_row + matrix_origin.x;
    *(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(pair_start);
    return;
  }
  const T first = src[matrix_origin.x * elements_per_row + matrix_origin.y];
  const T second = src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y];
  *(thread_elements()) = vec<T, 2>(first, second);
}
template <typename U>
METAL_FUNC void store(device U *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
T registerForm0 = ((thread T*)thread_elements())[0];
T registerForm1 = ((thread T*)thread_elements())[1];
dst[address0] = U(registerForm0);
dst[address1] = U(registerForm1);
} else if (elements_per_row % 2 != 0) {
uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
T registerForm0 = ((thread T*)thread_elements())[0];
T registerForm1 = ((thread T*)thread_elements())[1];
dst[address0] = U(registerForm0);
dst[address1] = U(registerForm1);
} else {
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
vec<T, 2> registerForm = *(thread_elements());
*(device vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
}
// Reads a single element from device memory into component 0 of the
// vec<T, 2> that thread_elements() points to; component 1 is left untouched.
//
// The element offset is `matrix_origin.y * elements_per_row + matrix_origin.x`,
// or `matrix_origin.x * elements_per_row + matrix_origin.y` when
// `transpose_matrix` is true.
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  // ushort operands are promoted to 32-bit int, and 65535 * 65535 exceeds
  // INT_MAX, so the offset is formed in ulong like the two-element load.
  const ulong stride = ulong(elements_per_row);
  const ulong offset = transpose_matrix
    ? ulong(matrix_origin.x) * stride + matrix_origin.y
    : ulong(matrix_origin.y) * stride + matrix_origin.x;
  thread_elements()[0][0] = src[offset];
}
template <typename U>
METAL_FUNC void store(threadgroup U *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
T registerForm0 = ((thread T*)thread_elements())[0];
T registerForm1 = ((thread T*)thread_elements())[1];
dst[address0] = U(registerForm0);
dst[address1] = U(registerForm1);
} else if (elements_per_row % 2 != 0) {
ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
T registerForm0 = ((thread T*)thread_elements())[0];
T registerForm1 = ((thread T*)thread_elements())[1];
dst[address0] = U(registerForm0);
dst[address1] = U(registerForm1);
} else {
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
vec<T, 2> registerForm = *(thread_elements());
*(threadgroup vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
}
// Reads a single element from device memory into component 1 of the
// vec<T, 2> that thread_elements() points to; component 0 is left untouched.
//
// The element offset is `matrix_origin.y * elements_per_row + matrix_origin.x`,
// or `matrix_origin.x * elements_per_row + matrix_origin.y` when
// `transpose_matrix` is true.
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  // ushort operands are promoted to 32-bit int, and 65535 * 65535 exceeds
  // INT_MAX, so the offset is formed in ulong like the two-element load.
  const ulong stride = ulong(elements_per_row);
  const ulong offset = transpose_matrix
    ? ulong(matrix_origin.x) * stride + matrix_origin.y
    : ulong(matrix_origin.y) * stride + matrix_origin.x;
  thread_elements()[0][1] = src[offset];
}
// Writes the vec<T, 2> that thread_elements() points to into device memory.
//
// Non-transposed: one vec<T, 2> write starting at
// `matrix_origin.y * elements_per_row + matrix_origin.x`.
// Transposed: component 0 goes to
// `matrix_origin.x * elements_per_row + matrix_origin.y` and component 1 goes
// one `elements_per_row` stride after it.
//
// NOTE(review): the vec<T, 2> access presumably needs the resulting address to
// be suitably aligned for vec<T, 2> -- confirm against callers.
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  // Widen BEFORE multiplying so the offset is formed in 64 bits on both
  // paths; previously the transposed path widened only after the 32-bit
  // product and the non-transposed path did not widen at all.
  const ulong stride = ulong(elements_per_row);
  if (transpose_matrix) {
    const ulong first = ulong(matrix_origin.x) * stride + matrix_origin.y;
    dst[first] = thread_elements()[0][0];
    // (x + 1) * stride + y == first + stride.
    dst[first + stride] = thread_elements()[0][1];
  } else {
    const ulong offset = ulong(matrix_origin.y) * stride + matrix_origin.x;
    *reinterpret_cast<device vec<T, 2>*>(dst + offset) = *(thread_elements());
  }
}
// Writes component 0 of the vec<T, 2> that thread_elements() points to into
// a single device-memory element.
//
// The element offset is `matrix_origin.y * elements_per_row + matrix_origin.x`,
// or `matrix_origin.x * elements_per_row + matrix_origin.y` when
// `transpose_matrix` is true.
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  // Widen BEFORE multiplying so neither path forms the product in 32 bits.
  const ulong stride = ulong(elements_per_row);
  const ulong offset = transpose_matrix
    ? ulong(matrix_origin.x) * stride + matrix_origin.y
    : ulong(matrix_origin.y) * stride + matrix_origin.x;
  dst[offset] = thread_elements()[0][0];
}
// Writes component 1 of the vec<T, 2> that thread_elements() points to into
// a single device-memory element.
//
// The element offset is `matrix_origin.y * elements_per_row + matrix_origin.x`,
// or `matrix_origin.x * elements_per_row + matrix_origin.y` when
// `transpose_matrix` is true.
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  // Widen BEFORE multiplying so neither path forms the product in 32 bits.
  const ulong stride = ulong(elements_per_row);
  const ulong offset = transpose_matrix
    ? ulong(matrix_origin.x) * stride + matrix_origin.y
    : ulong(matrix_origin.y) * stride + matrix_origin.x;
  dst[offset] = thread_elements()[0][1];
}
// Writes the vec<T, 2> that thread_elements() points to into threadgroup
// memory.
//
// Non-transposed: one vec<T, 2> write starting at
// `matrix_origin.y * elements_per_row + matrix_origin.x`.
// Transposed: two scalar writes, at row index `matrix_origin.x` and
// `matrix_origin.x + 1`, both at column index `matrix_origin.y`.
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  if (!transpose_matrix) {
    threadgroup T *pair_start = dst + matrix_origin.y * elements_per_row + matrix_origin.x;
    *reinterpret_cast<threadgroup vec<T, 2>*>(pair_start) = *(thread_elements());
    return;
  }
  const auto first = matrix_origin.x * elements_per_row + matrix_origin.y;
  const auto second = (matrix_origin.x + 1) * elements_per_row + matrix_origin.y;
  dst[first] = thread_elements()[0][0];
  dst[second] = thread_elements()[0][1];
}
template <typename U, typename V>
@ -343,6 +319,46 @@ namespace metal
}
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
}
// 'bfloat' is 'float' with the lower 16 bits set to garbage (BF16).
// Views this thread's float2 storage as four ushort components.
// The other bfloat helpers in this file read and write only .y and .w as the
// meaningful values, using .z as scratch.
METAL_FUNC thread ushort4* thread_elements_bfloat() thread {
  // The float2-typed local keeps this usable only when thread_elements()
  // yields a float2 pointer.
  thread float2* as_floats = thread_elements();
  return reinterpret_cast<thread ushort4*>(as_floats);
}
// Expands this thread's two ushort elements into a float-typed storage by
// placing them in the .y and .w components of a ushort4 and reinterpreting
// that as float2. The .x and .z components are deliberately left unset.
METAL_FUNC simdgroup_matrix_storage<float> unpack_bfloat() thread {
  ushort4 output;
  // thread_elements() yields a pointer, so it must be dereferenced before a
  // reference can bind to the underlying ushort2.
  thread ushort2& elements = *thread_elements();
  output.y = elements[0];
  output.w = elements[1];
  // Name the <float> specialization explicitly: inside the class template the
  // unqualified name refers to the current specialization, not the declared
  // return type.
  return simdgroup_matrix_storage<float>(as_type<float2>(output));
}
// Builds a ushort-typed storage from the .y and .w components of this
// thread's storage viewed as ushort4, dropping .x and .z.
METAL_FUNC simdgroup_matrix_storage<ushort> pack_bfloat() thread {
  thread ushort4* elements = thread_elements_bfloat();
  // Name the <ushort> specialization explicitly: inside the class template the
  // unqualified name refers to the current specialization, not the declared
  // return type.
  return simdgroup_matrix_storage<ushort>(ushort2(elements->y, elements->w));
}
// Loads two ushort values from threadgroup memory into the .y and .w
// components of this thread's storage viewed as ushort4.
//
// Transposed: two scalar reads, at row index `matrix_origin.x` and
// `matrix_origin.x + 1`, both at column index `matrix_origin.y`.
// Non-transposed: one ushort2 read starting at
// `matrix_origin.y * elements_per_row + matrix_origin.x`.
METAL_FUNC void load_bfloat(const threadgroup ushort *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  thread ushort4* halves = thread_elements_bfloat();
  if (transpose_matrix) {
    halves->y = src[matrix_origin.x * elements_per_row + matrix_origin.y];
    halves->w = src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y];
    return;
  }
  const threadgroup ushort *pair_start = src + matrix_origin.y * elements_per_row + matrix_origin.x;
  // The pair lands in .zw first; the first value is then copied from .z to .y
  // so the two values end up in .y and .w.
  halves->zw = *reinterpret_cast<const threadgroup ushort2*>(pair_start);
  halves->y = halves->z;
}
// Stores the .y and .w components of this thread's storage (viewed as
// ushort4) into threadgroup memory as two ushort values.
//
// Transposed: two scalar writes, at row index `matrix_origin.x` and
// `matrix_origin.x + 1`, both at column index `matrix_origin.y`.
// Non-transposed: one ushort2 write starting at
// `matrix_origin.y * elements_per_row + matrix_origin.x`. This path overwrites
// the .z component of the thread's storage as scratch.
METAL_FUNC void store_bfloat(threadgroup ushort *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
  // Use `->` for member access: `*(ptr).y` parses as `*((ptr).y)`, i.e. a
  // member access on the pointer itself.
  thread ushort4* halves = thread_elements_bfloat();
  if (transpose_matrix) {
    dst[matrix_origin.x * elements_per_row + matrix_origin.y] = halves->y;
    dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = halves->w;
  } else {
    // Copy .y into .z so that .zw holds the two values contiguously.
    halves->z = halves->y;
    // The destination holds ushort values, so the pair is written as ushort2,
    // mirroring the ushort2 read in load_bfloat.
    *reinterpret_cast<threadgroup ushort2*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = halves->zw;
  }
}
};
} // namespace metal
#pragma METAL internals : disable