mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add some missing where-cond kernels for metal. (#2203)
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
#include <metal_stdlib>
|
||||
#
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
@ -57,27 +56,31 @@ kernel void FN_NAME(
|
||||
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
|
||||
} \
|
||||
|
||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
|
||||
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
|
||||
// WHERE_OP(int64_t, int64_t, where_i64_i64)
|
||||
//
|
||||
// WHERE_OP(float, uint32_t, where_u32_f32)
|
||||
// WHERE_OP(double, uint32_t, where_u32_f64)
|
||||
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
|
||||
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
|
||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
WHERE_OP(half, uint32_t, where_u32_f16)
|
||||
WHERE_OP(float, uint32_t, where_u32_f32)
|
||||
WHERE_OP(uint8_t, uint32_t, where_u32_u8)
|
||||
WHERE_OP(uint32_t, uint32_t, where_u32_u32)
|
||||
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
WHERE_OP(half, uint8_t, where_u8_f16)
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||
WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
|
||||
WHERE_OP(half, int64_t, where_i64_f16)
|
||||
WHERE_OP(float, int64_t, where_i64_f32)
|
||||
WHERE_OP(uint8_t, int64_t, where_i64_u8)
|
||||
WHERE_OP(uint32_t, int64_t, where_i64_u32)
|
||||
WHERE_OP(int64_t, int64_t, where_i64_i64)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
WHERE_OP(bfloat, int64_t, where_i64_bf16)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
|
||||
#endif
|
||||
WHERE_OP(bfloat, uint32_t, where_u32_bf16)
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user