Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Revert "Alter metal simdgroup matrix load/store ops"
This reverts commit 560f666d29.
@@ -12,7 +12,8 @@
 struct _simdgroup_event_t {};
 
 thread _simdgroup_event_t* __metal_simdgroup_async_copy_1d(
-    ulong, ulong, threadgroup void *, const device void *, ulong)
+    ulong, ulong,
+    threadgroup void*, const device void*, ulong)
     __asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
 
 thread _simdgroup_event_t* __metal_simdgroup_async_copy_1d(
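
A note on the intrinsic this hunk touches: the `__asm("air.simdgroup_async_copy_1d.p3i8.p1i8")` attribute binds an otherwise bodiless declaration to a private AIR intrinsic, which the header then wraps in a typed event object. The sketch below shows how a kernel might drive such a wrapper; the `simdgroup_event` type and its `async_copy`/`wait` members are assumptions modeled on this header's style, not code from the commit.

// Hypothetical usage sketch (assumed names, see note above).
kernel void stage_tile(const device float *src [[buffer(0)]],
                       device float *out [[buffer(1)]],
                       threadgroup float *tile [[threadgroup(0)]],
                       uint lane [[thread_index_in_simdgroup]]) {
  simdgroup_event event;
  if (lane == 0) {
    // One lane issues the copy; the hardware stages 64 elements from
    // device memory into the threadgroup tile asynchronously.
    event.async_copy(tile, src, 64);
  }
  // Assumed helper: block until the listed copy events have completed.
  simdgroup_event::wait(1, &event);
  out[lane] = tile[lane];
}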
@@ -53,15 +54,15 @@ namespace metal
       threadgroup T *dst,
       const device T *src,
       ulong n_elements
-    ) {
+    ) thread {
       event = *__metal_simdgroup_async_copy_1d(
         // Description of the data type.
         sizeof(T),
         alignof(T),
 
         // Description of the arguments.
-        reinterpret_cast<threadgroup void *>(dst),
-        reinterpret_cast<const device void *>(src),
+        reinterpret_cast<threadgroup void*>(dst),
+        reinterpret_cast<const device void*>(src),
         n_elements);
     }
 
@@ -70,7 +71,7 @@ namespace metal
       device T *dst,
       const threadgroup T *src,
       ulong n_elements
-    ) {
+    ) thread {
       event = *__metal_simdgroup_async_copy_1d(
         // Description of the data type.
         sizeof(T),
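
The recurring `) {` versus `) thread {` change in these hunks is Metal's member-function address-space qualifier: much like a trailing `const` in C++, a trailing `thread` qualifies the implicit `this` pointer so the method applies to objects living in per-thread (register) storage. A minimal illustration with a hypothetical type, not taken from the commit:

struct accumulator {
  float total;
  // The trailing `thread` qualifies `this`, restricting the method to
  // instances in the thread address space.
  void add(float x) thread { total += x; }
};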
@@ -228,112 +229,87 @@ namespace metal
     }
 
     METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
-      } else {
-        return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
-      }
+      if (transpose_matrix) {
+        return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
+      } else {
+        return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
+      }
     }
 
     METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        return src + matrix_origin.x * elements_per_row + matrix_origin.y;
-      } else {
-        return src + matrix_origin.y * elements_per_row + matrix_origin.x;
-      }
+      if (transpose_matrix) {
+        return src + matrix_origin.x * elements_per_row + matrix_origin.y;
+      } else {
+        return src + matrix_origin.y * elements_per_row + matrix_origin.x;
+      }
     }
 
     // WARNING: All load and store functions assume the X dimension is divisible by 2.
-    template <typename U>
-    METAL_FUNC void load(const device U *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
-        uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
-        U memoryForm0 = src[address0];
-        U memoryForm1 = src[address1];
-        ((thread T*)thread_elements())[0] = T(memoryForm0);
-        ((thread T*)thread_elements())[1] = T(memoryForm1);
-      } else if (elements_per_row % 2 != 0) {
-        uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
-        uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
-        U memoryForm0 = src[address0];
-        U memoryForm1 = src[address1];
-        ((thread T*)thread_elements())[0] = T(memoryForm0);
-        ((thread T*)thread_elements())[1] = T(memoryForm1);
-      } else {
-        auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
-        vec<U, 2> memoryForm = *(const device vec<U, 2>*)(src + combinedAddress);
-        *(thread_elements()) = vec<T, 2>(memoryForm);
-      }
-
+    METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        *(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
+      } else {
+        *(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
+      }
     }
 
-    template <typename U>
-    METAL_FUNC void load(const threadgroup U *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
-        ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
-        U memoryForm0 = src[address0];
-        U memoryForm1 = src[address1];
-        ((thread T*)thread_elements())[0] = T(memoryForm0);
-        ((thread T*)thread_elements())[1] = T(memoryForm1);
-      } else if (elements_per_row % 2 != 0) {
-        ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
-        ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
-        U memoryForm0 = src[address0];
-        U memoryForm1 = src[address1];
-        ((thread T*)thread_elements())[0] = T(memoryForm0);
-        ((thread T*)thread_elements())[1] = T(memoryForm1);
-      } else {
-        auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
-        vec<U, 2> memoryForm = *(const threadgroup vec<U, 2>*)(src + combinedAddress);
-        *(thread_elements()) = vec<T, 2>(memoryForm);
-      }
+    METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        *(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
+      } else {
+        *(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
+      }
     }
 
-    template <typename U>
-    METAL_FUNC void store(device U *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
-        uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
-        T registerForm0 = ((thread T*)thread_elements())[0];
-        T registerForm1 = ((thread T*)thread_elements())[1];
-        dst[address0] = U(registerForm0);
-        dst[address1] = U(registerForm1);
-      } else if (elements_per_row % 2 != 0) {
-        uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
-        uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
-        T registerForm0 = ((thread T*)thread_elements())[0];
-        T registerForm1 = ((thread T*)thread_elements())[1];
-        dst[address0] = U(registerForm0);
-        dst[address1] = U(registerForm1);
-      } else {
-        auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
-        vec<T, 2> registerForm = *(thread_elements());
-        *(device vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
-      }
+    METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
+      } else {
+        thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
+      }
     }
 
-    template <typename U>
-    METAL_FUNC void store(threadgroup U *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
-      if (transpose_matrix) {
-        ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
-        ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
-        T registerForm0 = ((thread T*)thread_elements())[0];
-        T registerForm1 = ((thread T*)thread_elements())[1];
-        dst[address0] = U(registerForm0);
-        dst[address1] = U(registerForm1);
-      } else if (elements_per_row % 2 != 0) {
-        ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
-        ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
-        T registerForm0 = ((thread T*)thread_elements())[0];
-        T registerForm1 = ((thread T*)thread_elements())[1];
-        dst[address0] = U(registerForm0);
-        dst[address1] = U(registerForm1);
-      } else {
-        auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
-        vec<T, 2> registerForm = *(thread_elements());
-        *(threadgroup vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
-      }
+    METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
+      } else {
+        thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
+      }
     }
 
+    METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
+        dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
+      } else {
+        *reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
+      }
+    }
+
+    METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
+      } else {
+        dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
+      }
+    }
+
+    METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
+      } else {
+        dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
+      }
+    }
+
+    METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
+        dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
+      } else {
+        *reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
+      }
+    }
+
     template <typename U, typename V>
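
On the index math restored in the hunk above: every function addresses a row-major matrix whose rows are `elements_per_row` elements wide, so a thread's 2-element slice begins at offset `y * elements_per_row + x`; transposition swaps the roles of `x` and `y`, which also breaks the contiguity of the pair and forces the two scalar accesses seen in each transposed branch. The helper below restates that rule in isolation; `slice_offset` is an illustrative name, not part of the commit.

// Row-major offset of a thread's 2-wide slice (illustration only).
METAL_FUNC ulong slice_offset(uint2 origin, uint elements_per_row,
                              bool transpose_matrix) {
  return transpose_matrix
      ? ulong(origin.x) * elements_per_row + origin.y   // strided: two scalar accesses
      : ulong(origin.y) * elements_per_row + origin.x;  // contiguous: one vec<T, 2> access
}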
@@ -343,6 +319,46 @@ namespace metal
       }
       t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
     }
+
+    // 'bfloat' is 'float' with the lower 16 bits set to garbage (BF16).
+
+    METAL_FUNC thread ushort4* thread_elements_bfloat() thread {
+      thread float2* elements = thread_elements();
+      return reinterpret_cast<thread ushort4*>(elements);
+    }
+
+    METAL_FUNC simdgroup_matrix_storage<float> unpack_bfloat() thread {
+      ushort4 output;
+      thread ushort2& elements = *(thread_elements());
+      output.y = elements[0];
+      output.w = elements[1];
+      return simdgroup_matrix_storage<float>(as_type<float2>(output));
+    }
+
+    METAL_FUNC simdgroup_matrix_storage<ushort> pack_bfloat() thread {
+      thread ushort4* elements = thread_elements_bfloat();
+      return simdgroup_matrix_storage<ushort>(ushort2(elements->y, elements->w));
+    }
+
+    METAL_FUNC void load_bfloat(const threadgroup ushort *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        thread_elements_bfloat()->y = src[matrix_origin.x * elements_per_row + matrix_origin.y];
+        thread_elements_bfloat()->w = src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y];
+      } else {
+        thread_elements_bfloat()->zw = *reinterpret_cast<const threadgroup ushort2*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
+        thread_elements_bfloat()->y = thread_elements_bfloat()->z;
+      }
+    }
+
+    METAL_FUNC void store_bfloat(threadgroup ushort *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
+      if (transpose_matrix) {
+        dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements_bfloat()->y;
+        dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements_bfloat()->w;
+      } else {
+        thread_elements_bfloat()->z = thread_elements_bfloat()->y;
+        *reinterpret_cast<threadgroup ushort2*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = thread_elements_bfloat()->zw;
+      }
+    }
   };
 } // namespace metal
 #pragma METAL internals : disable
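
The bfloat helpers restored in the last hunk lean on one bit-level fact: a bfloat16 value is exactly the upper 16 bits of the corresponding IEEE float32, which is why the header can speak of the lower 16 bits as garbage. A self-contained sketch of the same pack/unpack trick, with assumed function names, independent of this header:

#include <metal_stdlib>
using namespace metal;

// Unpack: place the bf16 bits in the high half of a 32-bit word and leave
// the low half as zero-filled "garbage", mirroring unpack_bfloat above.
inline float bf16_bits_to_float(ushort bits) {
  return as_type<float>(ushort2(0, bits));  // .y is the high half on little-endian GPUs
}

// Pack: keep only the high 16 bits of the float32, mirroring pack_bfloat's
// selection of the .y/.w lanes.
inline ushort float_to_bf16_bits(float x) {
  return as_type<ushort2>(x).y;
}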