mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Merge branch 'main' into ivarflakstad/metal-prng
This commit is contained in:
@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define AFFINE(FN_NAME, TYPENAME) \
|
||||
#define AFFINE(FN_NAME, T) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
device const T *input, \
|
||||
device T *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[id] = TYPENAME(float(input[id]) * mul + add); \
|
||||
output[id] = T(fma(float(input[id]), mul, add)); \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
|
||||
constant size_t *strides, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
device const T *input, \
|
||||
device T *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
|
||||
output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
|
||||
}
|
||||
|
||||
#define POWF(FN_NAME, TYPENAME) \
|
||||
|
@ -17,29 +17,45 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template<typename T, typename ID>
|
||||
METAL_FUNC void where_cond(
|
||||
constant size_t &numel,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t *strides_t,
|
||||
constant size_t *strides_f,
|
||||
device const ID *ids,
|
||||
device const T *t,
|
||||
device const T *f,
|
||||
device T *out,
|
||||
uint i [[ thread_position_in_grid ]]
|
||||
) {
|
||||
if (i >= numel){
|
||||
return;
|
||||
}
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
|
||||
}
|
||||
|
||||
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID_TYPENAME *ids, \
|
||||
device const TYPENAME *t, \
|
||||
device const TYPENAME *f, \
|
||||
device TYPENAME *out ,\
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (i >= numel){ \
|
||||
return; \
|
||||
} \
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
|
||||
} \
|
||||
#define WHERE_OP(T, ID, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID *ids, \
|
||||
device const T *t, \
|
||||
device const T *f, \
|
||||
device T *out, \
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
|
||||
} \
|
||||
|
||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||
@ -54,10 +70,14 @@ kernel void FN_NAME( \
|
||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
||||
WHERE_OP(half, uint8_t, where_u8_f16)
|
||||
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
|
||||
#endif
|
Reference in New Issue
Block a user