mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Revert "Testing ushort intermediate in case combo of async and f16/bf16 is the issue"
This reverts commit c3b0757995
.
This commit is contained in:
@ -48,15 +48,16 @@ namespace metal
|
|||||||
struct simdgroup_event {
|
struct simdgroup_event {
|
||||||
METAL_FUNC simdgroup_event() thread {}
|
METAL_FUNC simdgroup_event() thread {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
METAL_FUNC void async_copy(
|
METAL_FUNC void async_copy(
|
||||||
threadgroup float *dst,
|
threadgroup T *dst,
|
||||||
const device float *src,
|
const device T *src,
|
||||||
ulong n_elements
|
ulong n_elements
|
||||||
) {
|
) {
|
||||||
event = *__metal_simdgroup_async_copy_1d(
|
event = *__metal_simdgroup_async_copy_1d(
|
||||||
// Description of the data type.
|
// Description of the data type.
|
||||||
sizeof(float),
|
sizeof(T),
|
||||||
alignof(float),
|
alignof(T),
|
||||||
|
|
||||||
// Description of the arguments.
|
// Description of the arguments.
|
||||||
reinterpret_cast<threadgroup void *>(dst),
|
reinterpret_cast<threadgroup void *>(dst),
|
||||||
@ -64,51 +65,16 @@ namespace metal
|
|||||||
n_elements);
|
n_elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
METAL_FUNC void async_copy(
|
METAL_FUNC void async_copy(
|
||||||
threadgroup bfloat *dst,
|
device T *dst,
|
||||||
const device bfloat *src,
|
const threadgroup T *src,
|
||||||
ulong n_elements
|
|
||||||
) {
|
|
||||||
threadgroup ushort *re_dst = reinterpret_cast<threadgroup ushort *>(dst);
|
|
||||||
const device ushort *re_src = reinterpret_cast<const device ushort *>(src);
|
|
||||||
event = *__metal_simdgroup_async_copy_1d(
|
|
||||||
// Description of the data type.
|
|
||||||
sizeof(ushort),
|
|
||||||
alignof(ushort),
|
|
||||||
|
|
||||||
// Description of the arguments.
|
|
||||||
reinterpret_cast<threadgroup void *>(re_dst),
|
|
||||||
reinterpret_cast<const device void *>(re_src),
|
|
||||||
n_elements);
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
|
||||||
threadgroup half *dst,
|
|
||||||
const device half *src,
|
|
||||||
ulong n_elements
|
|
||||||
) {
|
|
||||||
threadgroup ushort *re_dst = reinterpret_cast<threadgroup ushort *>(dst);
|
|
||||||
const device ushort *re_src = reinterpret_cast<const device ushort *>(src);
|
|
||||||
event = *__metal_simdgroup_async_copy_1d(
|
|
||||||
// Description of the data type.
|
|
||||||
sizeof(ushort),
|
|
||||||
alignof(ushort),
|
|
||||||
|
|
||||||
// Description of the arguments.
|
|
||||||
reinterpret_cast<threadgroup void *>(re_dst),
|
|
||||||
reinterpret_cast<const device void *>(re_src),
|
|
||||||
n_elements);
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
|
||||||
device float *dst,
|
|
||||||
const threadgroup float *src,
|
|
||||||
ulong n_elements
|
ulong n_elements
|
||||||
) {
|
) {
|
||||||
event = *__metal_simdgroup_async_copy_1d(
|
event = *__metal_simdgroup_async_copy_1d(
|
||||||
// Description of the data type.
|
// Description of the data type.
|
||||||
sizeof(float),
|
sizeof(T),
|
||||||
alignof(float),
|
alignof(T),
|
||||||
|
|
||||||
// Description of the arguments.
|
// Description of the arguments.
|
||||||
reinterpret_cast<device void *>(dst),
|
reinterpret_cast<device void *>(dst),
|
||||||
@ -116,50 +82,15 @@ namespace metal
|
|||||||
n_elements);
|
n_elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
template <typename T>
|
||||||
device bfloat *dst,
|
|
||||||
const threadgroup bfloat *src,
|
|
||||||
ulong n_elements
|
|
||||||
) {
|
|
||||||
device ushort *re_dst = reinterpret_cast<device ushort *>(dst);
|
|
||||||
const threadgroup ushort *re_src = reinterpret_cast<const threadgroup ushort *>(src);
|
|
||||||
event = *__metal_simdgroup_async_copy_1d(
|
|
||||||
// Description of the data type.
|
|
||||||
sizeof(ushort),
|
|
||||||
alignof(ushort),
|
|
||||||
|
|
||||||
// Description of the arguments.
|
|
||||||
reinterpret_cast<device void *>(re_dst),
|
|
||||||
reinterpret_cast<const threadgroup void *>(re_src),
|
|
||||||
n_elements);
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
|
||||||
device half *dst,
|
|
||||||
const threadgroup half *src,
|
|
||||||
ulong n_elements
|
|
||||||
) {
|
|
||||||
device ushort *re_dst = reinterpret_cast<device ushort *>(dst);
|
|
||||||
const threadgroup ushort *re_src = reinterpret_cast<const threadgroup ushort *>(src);
|
|
||||||
event = *__metal_simdgroup_async_copy_1d(
|
|
||||||
// Description of the data type.
|
|
||||||
sizeof(ushort),
|
|
||||||
alignof(ushort),
|
|
||||||
|
|
||||||
// Description of the arguments.
|
|
||||||
reinterpret_cast<device void *>(re_dst),
|
|
||||||
reinterpret_cast<const threadgroup void *>(re_src),
|
|
||||||
n_elements);
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
METAL_FUNC void async_copy(
|
||||||
// Description of the destination.
|
// Description of the destination.
|
||||||
threadgroup float *dst,
|
threadgroup T *dst,
|
||||||
ushort dst_elements_per_row,
|
ushort dst_elements_per_row,
|
||||||
ushort2 dst_tile_dimensions,
|
ushort2 dst_tile_dimensions,
|
||||||
|
|
||||||
// Description of the source.
|
// Description of the source.
|
||||||
const device float *src,
|
const device T *src,
|
||||||
uint src_elements_per_row,
|
uint src_elements_per_row,
|
||||||
ushort2 src_tile_dimensions,
|
ushort2 src_tile_dimensions,
|
||||||
|
|
||||||
@ -174,8 +105,8 @@ namespace metal
|
|||||||
}
|
}
|
||||||
event = *__metal_simdgroup_async_copy_2d(
|
event = *__metal_simdgroup_async_copy_2d(
|
||||||
// Description of the data type.
|
// Description of the data type.
|
||||||
sizeof(float),
|
sizeof(T),
|
||||||
alignof(float),
|
alignof(T),
|
||||||
|
|
||||||
// Description of the destination.
|
// Description of the destination.
|
||||||
reinterpret_cast<threadgroup void *>(dst),
|
reinterpret_cast<threadgroup void *>(dst),
|
||||||
@ -194,104 +125,15 @@ namespace metal
|
|||||||
static_cast<int>(clamp_mode));
|
static_cast<int>(clamp_mode));
|
||||||
}
|
}
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
template <typename T>
|
||||||
// Description of the destination.
|
|
||||||
threadgroup bfloat *dst,
|
|
||||||
ushort dst_elements_per_row,
|
|
||||||
ushort2 dst_tile_dimensions,
|
|
||||||
|
|
||||||
// Description of the source.
|
|
||||||
const device bfloat *src,
|
|
||||||
uint src_elements_per_row,
|
|
||||||
ushort2 src_tile_dimensions,
|
|
||||||
|
|
||||||
// Other arguments.
|
|
||||||
bool transpose_matrix = false,
|
|
||||||
simdgroup_async_copy_clamp_mode clamp_mode =
|
|
||||||
simdgroup_async_copy_clamp_mode::clamp_to_zero
|
|
||||||
) thread {
|
|
||||||
if (transpose_matrix) {
|
|
||||||
src_tile_dimensions = src_tile_dimensions.yx;
|
|
||||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup ushort *re_dst = reinterpret_cast<threadgroup ushort *>(dst);
|
|
||||||
const device ushort *re_src = reinterpret_cast<const device ushort *>(src);
|
|
||||||
event = *__metal_simdgroup_async_copy_2d(
|
|
||||||
// Description of the data type.
|
|
||||||
sizeof(ushort),
|
|
||||||
alignof(ushort),
|
|
||||||
|
|
||||||
// Description of the destination.
|
|
||||||
reinterpret_cast<threadgroup void *>(re_dst),
|
|
||||||
ushort(dst_elements_per_row),
|
|
||||||
1,
|
|
||||||
ulong2(dst_tile_dimensions),
|
|
||||||
|
|
||||||
// Description of the source.
|
|
||||||
reinterpret_cast<const device void *>(re_src),
|
|
||||||
uint(src_elements_per_row),
|
|
||||||
1,
|
|
||||||
ulong2(src_tile_dimensions),
|
|
||||||
|
|
||||||
// Other arguments.
|
|
||||||
long2(0),
|
|
||||||
static_cast<int>(clamp_mode));
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
METAL_FUNC void async_copy(
|
||||||
// Description of the destination.
|
// Description of the destination.
|
||||||
threadgroup half *dst,
|
device T *dst,
|
||||||
ushort dst_elements_per_row,
|
|
||||||
ushort2 dst_tile_dimensions,
|
|
||||||
|
|
||||||
// Description of the source.
|
|
||||||
const device half *src,
|
|
||||||
uint src_elements_per_row,
|
|
||||||
ushort2 src_tile_dimensions,
|
|
||||||
|
|
||||||
// Other arguments.
|
|
||||||
bool transpose_matrix = false,
|
|
||||||
simdgroup_async_copy_clamp_mode clamp_mode =
|
|
||||||
simdgroup_async_copy_clamp_mode::clamp_to_zero
|
|
||||||
) thread {
|
|
||||||
if (transpose_matrix) {
|
|
||||||
src_tile_dimensions = src_tile_dimensions.yx;
|
|
||||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup ushort *re_dst = reinterpret_cast<threadgroup ushort *>(dst);
|
|
||||||
const device ushort *re_src = reinterpret_cast<const device ushort *>(src);
|
|
||||||
event = *__metal_simdgroup_async_copy_2d(
|
|
||||||
// Description of the data type.
|
|
||||||
sizeof(ushort),
|
|
||||||
alignof(ushort),
|
|
||||||
|
|
||||||
// Description of the destination.
|
|
||||||
reinterpret_cast<threadgroup void *>(re_dst),
|
|
||||||
ushort(dst_elements_per_row),
|
|
||||||
1,
|
|
||||||
ulong2(dst_tile_dimensions),
|
|
||||||
|
|
||||||
// Description of the source.
|
|
||||||
reinterpret_cast<const device void *>(re_src),
|
|
||||||
uint(src_elements_per_row),
|
|
||||||
1,
|
|
||||||
ulong2(src_tile_dimensions),
|
|
||||||
|
|
||||||
// Other arguments.
|
|
||||||
long2(0),
|
|
||||||
static_cast<int>(clamp_mode));
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
|
||||||
// Description of the destination.
|
|
||||||
device float *dst,
|
|
||||||
uint dst_elements_per_row,
|
uint dst_elements_per_row,
|
||||||
ushort2 dst_tile_dimensions,
|
ushort2 dst_tile_dimensions,
|
||||||
|
|
||||||
// Description of the source.
|
// Description of the source.
|
||||||
const threadgroup float *src,
|
const threadgroup T *src,
|
||||||
ushort src_elements_per_row,
|
ushort src_elements_per_row,
|
||||||
ushort2 src_tile_dimensions,
|
ushort2 src_tile_dimensions,
|
||||||
|
|
||||||
@ -304,8 +146,8 @@ namespace metal
|
|||||||
}
|
}
|
||||||
event = *__metal_simdgroup_async_copy_2d(
|
event = *__metal_simdgroup_async_copy_2d(
|
||||||
// Description of the data type.
|
// Description of the data type.
|
||||||
sizeof(float),
|
sizeof(T),
|
||||||
alignof(float),
|
alignof(T),
|
||||||
|
|
||||||
// Description of the destination.
|
// Description of the destination.
|
||||||
reinterpret_cast<device void *>(dst),
|
reinterpret_cast<device void *>(dst),
|
||||||
@ -324,90 +166,6 @@ namespace metal
|
|||||||
0);
|
0);
|
||||||
}
|
}
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
|
||||||
// Description of the destination.
|
|
||||||
device bfloat *dst,
|
|
||||||
uint dst_elements_per_row,
|
|
||||||
ushort2 dst_tile_dimensions,
|
|
||||||
|
|
||||||
// Description of the source.
|
|
||||||
const threadgroup bfloat *src,
|
|
||||||
ushort src_elements_per_row,
|
|
||||||
ushort2 src_tile_dimensions,
|
|
||||||
|
|
||||||
// Other arguments.
|
|
||||||
bool transpose_matrix = false
|
|
||||||
) thread {
|
|
||||||
if (transpose_matrix) {
|
|
||||||
src_tile_dimensions = src_tile_dimensions.yx;
|
|
||||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
|
||||||
}
|
|
||||||
device ushort *re_dst = reinterpret_cast<device ushort *>(dst);
|
|
||||||
const threadgroup ushort *re_src = reinterpret_cast<const threadgroup ushort *>(src);
|
|
||||||
event = *__metal_simdgroup_async_copy_2d(
|
|
||||||
// Description of the data type.
|
|
||||||
sizeof(ushort),
|
|
||||||
alignof(ushort),
|
|
||||||
|
|
||||||
// Description of the destination.
|
|
||||||
reinterpret_cast<device void *>(re_dst),
|
|
||||||
uint(dst_elements_per_row),
|
|
||||||
1,
|
|
||||||
ulong2(dst_tile_dimensions),
|
|
||||||
|
|
||||||
// Description of the source.
|
|
||||||
reinterpret_cast<const threadgroup void *>(re_src),
|
|
||||||
ushort(src_elements_per_row),
|
|
||||||
1,
|
|
||||||
ulong2(src_tile_dimensions),
|
|
||||||
|
|
||||||
// Other arguments.
|
|
||||||
long2(0),
|
|
||||||
0);
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void async_copy(
|
|
||||||
// Description of the destination.
|
|
||||||
device half *dst,
|
|
||||||
uint dst_elements_per_row,
|
|
||||||
ushort2 dst_tile_dimensions,
|
|
||||||
|
|
||||||
// Description of the source.
|
|
||||||
const threadgroup half *src,
|
|
||||||
ushort src_elements_per_row,
|
|
||||||
ushort2 src_tile_dimensions,
|
|
||||||
|
|
||||||
// Other arguments.
|
|
||||||
bool transpose_matrix = false
|
|
||||||
) thread {
|
|
||||||
if (transpose_matrix) {
|
|
||||||
src_tile_dimensions = src_tile_dimensions.yx;
|
|
||||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
|
||||||
}
|
|
||||||
device ushort *re_dst = reinterpret_cast<device ushort *>(dst);
|
|
||||||
const threadgroup ushort *re_src = reinterpret_cast<const threadgroup ushort *>(src);
|
|
||||||
event = *__metal_simdgroup_async_copy_2d(
|
|
||||||
// Description of the data type.
|
|
||||||
sizeof(ushort),
|
|
||||||
alignof(ushort),
|
|
||||||
|
|
||||||
// Description of the destination.
|
|
||||||
reinterpret_cast<device void *>(re_dst),
|
|
||||||
uint(dst_elements_per_row),
|
|
||||||
1,
|
|
||||||
ulong2(dst_tile_dimensions),
|
|
||||||
|
|
||||||
// Description of the source.
|
|
||||||
reinterpret_cast<const threadgroup void *>(re_src),
|
|
||||||
ushort(src_elements_per_row),
|
|
||||||
1,
|
|
||||||
ulong2(src_tile_dimensions),
|
|
||||||
|
|
||||||
// Other arguments.
|
|
||||||
long2(0),
|
|
||||||
0);
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
|
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
|
||||||
__metal_wait_simdgroup_events(count, reinterpret_cast<const thread _simdgroup_event_t**>(events));
|
__metal_wait_simdgroup_events(count, reinterpret_cast<const thread _simdgroup_event_t**>(events));
|
||||||
}
|
}
|
||||||
@ -955,7 +713,6 @@ void _gemm_impl(device T *A [[buffer(0)]],
|
|||||||
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
|
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
|
||||||
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
|
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
if (K_floor + K_simd < K) {
|
if (K_floor + K_simd < K) {
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
|
Reference in New Issue
Block a user