mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Revert "Testing ushort intermediate in case combo of async and f16/bf16 is the issue"
This reverts commit c3b0757995
.
This commit is contained in:
@ -48,15 +48,16 @@ namespace metal
|
||||
struct simdgroup_event {
|
||||
METAL_FUNC simdgroup_event() thread {}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
threadgroup float *dst,
|
||||
const device float *src,
|
||||
threadgroup T *dst,
|
||||
const device T *src,
|
||||
ulong n_elements
|
||||
) {
|
||||
event = *__metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(float),
|
||||
alignof(float),
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<threadgroup void *>(dst),
|
||||
@ -64,51 +65,16 @@ namespace metal
|
||||
n_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
threadgroup bfloat *dst,
|
||||
const device bfloat *src,
|
||||
ulong n_elements
|
||||
) {
|
||||
threadgroup ushort *re_dst = reinterpret_cast<threadgroup ushort *>(dst);
|
||||
const device ushort *re_src = reinterpret_cast<const device ushort *>(src);
|
||||
event = *__metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(ushort),
|
||||
alignof(ushort),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<threadgroup void *>(re_dst),
|
||||
reinterpret_cast<const device void *>(re_src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
METAL_FUNC void async_copy(
|
||||
threadgroup half *dst,
|
||||
const device half *src,
|
||||
ulong n_elements
|
||||
) {
|
||||
threadgroup ushort *re_dst = reinterpret_cast<threadgroup ushort *>(dst);
|
||||
const device ushort *re_src = reinterpret_cast<const device ushort *>(src);
|
||||
event = *__metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(ushort),
|
||||
alignof(ushort),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<threadgroup void *>(re_dst),
|
||||
reinterpret_cast<const device void *>(re_src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
METAL_FUNC void async_copy(
|
||||
device float *dst,
|
||||
const threadgroup float *src,
|
||||
device T *dst,
|
||||
const threadgroup T *src,
|
||||
ulong n_elements
|
||||
) {
|
||||
event = *__metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(float),
|
||||
alignof(float),
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<device void *>(dst),
|
||||
@ -116,50 +82,15 @@ namespace metal
|
||||
n_elements);
|
||||
}
|
||||
|
||||
METAL_FUNC void async_copy(
|
||||
device bfloat *dst,
|
||||
const threadgroup bfloat *src,
|
||||
ulong n_elements
|
||||
) {
|
||||
device ushort *re_dst = reinterpret_cast<device ushort *>(dst);
|
||||
const threadgroup ushort *re_src = reinterpret_cast<const threadgroup ushort *>(src);
|
||||
event = *__metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(ushort),
|
||||
alignof(ushort),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<device void *>(re_dst),
|
||||
reinterpret_cast<const threadgroup void *>(re_src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
METAL_FUNC void async_copy(
|
||||
device half *dst,
|
||||
const threadgroup half *src,
|
||||
ulong n_elements
|
||||
) {
|
||||
device ushort *re_dst = reinterpret_cast<device ushort *>(dst);
|
||||
const threadgroup ushort *re_src = reinterpret_cast<const threadgroup ushort *>(src);
|
||||
event = *__metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(ushort),
|
||||
alignof(ushort),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<device void *>(re_dst),
|
||||
reinterpret_cast<const threadgroup void *>(re_src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
threadgroup float *dst,
|
||||
threadgroup T *dst,
|
||||
ushort dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const device float *src,
|
||||
const device T *src,
|
||||
uint src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
@ -174,8 +105,8 @@ namespace metal
|
||||
}
|
||||
event = *__metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(float),
|
||||
alignof(float),
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<threadgroup void *>(dst),
|
||||
@ -194,104 +125,15 @@ namespace metal
|
||||
static_cast<int>(clamp_mode));
|
||||
}
|
||||
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
threadgroup bfloat *dst,
|
||||
ushort dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const device bfloat *src,
|
||||
uint src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false,
|
||||
simdgroup_async_copy_clamp_mode clamp_mode =
|
||||
simdgroup_async_copy_clamp_mode::clamp_to_zero
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
|
||||
threadgroup ushort *re_dst = reinterpret_cast<threadgroup ushort *>(dst);
|
||||
const device ushort *re_src = reinterpret_cast<const device ushort *>(src);
|
||||
event = *__metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(ushort),
|
||||
alignof(ushort),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<threadgroup void *>(re_dst),
|
||||
ushort(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const device void *>(re_src),
|
||||
uint(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
static_cast<int>(clamp_mode));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
threadgroup half *dst,
|
||||
ushort dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const device half *src,
|
||||
uint src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false,
|
||||
simdgroup_async_copy_clamp_mode clamp_mode =
|
||||
simdgroup_async_copy_clamp_mode::clamp_to_zero
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
|
||||
threadgroup ushort *re_dst = reinterpret_cast<threadgroup ushort *>(dst);
|
||||
const device ushort *re_src = reinterpret_cast<const device ushort *>(src);
|
||||
event = *__metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(ushort),
|
||||
alignof(ushort),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<threadgroup void *>(re_dst),
|
||||
ushort(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const device void *>(re_src),
|
||||
uint(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
static_cast<int>(clamp_mode));
|
||||
}
|
||||
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
device float *dst,
|
||||
device T *dst,
|
||||
uint dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const threadgroup float *src,
|
||||
const threadgroup T *src,
|
||||
ushort src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
@ -304,8 +146,8 @@ namespace metal
|
||||
}
|
||||
event = *__metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(float),
|
||||
alignof(float),
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<device void *>(dst),
|
||||
@ -324,90 +166,6 @@ namespace metal
|
||||
0);
|
||||
}
|
||||
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
device bfloat *dst,
|
||||
uint dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const threadgroup bfloat *src,
|
||||
ushort src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
device ushort *re_dst = reinterpret_cast<device ushort *>(dst);
|
||||
const threadgroup ushort *re_src = reinterpret_cast<const threadgroup ushort *>(src);
|
||||
event = *__metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(ushort),
|
||||
alignof(ushort),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<device void *>(re_dst),
|
||||
uint(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const threadgroup void *>(re_src),
|
||||
ushort(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
0);
|
||||
}
|
||||
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
device half *dst,
|
||||
uint dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const threadgroup half *src,
|
||||
ushort src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
device ushort *re_dst = reinterpret_cast<device ushort *>(dst);
|
||||
const threadgroup ushort *re_src = reinterpret_cast<const threadgroup ushort *>(src);
|
||||
event = *__metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(ushort),
|
||||
alignof(ushort),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<device void *>(re_dst),
|
||||
uint(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const threadgroup void *>(re_src),
|
||||
ushort(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
0);
|
||||
}
|
||||
|
||||
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
|
||||
__metal_wait_simdgroup_events(count, reinterpret_cast<const thread _simdgroup_event_t**>(events));
|
||||
}
|
||||
@ -955,7 +713,6 @@ void _gemm_impl(device T *A [[buffer(0)]],
|
||||
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
|
||||
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (K_floor + K_simd < K) {
|
||||
#pragma clang loop unroll(full)
|
||||
|
Reference in New Issue
Block a user