mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Add the kernels.
This commit is contained in:
@ -38,4 +38,5 @@ AFFINE_OP(__half, affine_f16)
|
||||
|
||||
AFFINE_OP(float, affine_f32)
|
||||
AFFINE_OP(double, affine_f64)
|
||||
AFFINE_OP(uint8_t, affine_u8)
|
||||
AFFINE_OP(uint32_t, affine_u32)
|
||||
|
@ -17,13 +17,17 @@ BINARY_OP(__half, bsub_f16, x - y)
|
||||
|
||||
BINARY_OP(float, badd_f32, x + y)
|
||||
BINARY_OP(double, badd_f64, x + y);
|
||||
BINARY_OP(uint8_t, badd_u8, x + y);
|
||||
BINARY_OP(uint32_t, badd_u32, x + y);
|
||||
BINARY_OP(float, bdiv_f32, x / y)
|
||||
BINARY_OP(double, bdiv_f64, x / y);
|
||||
BINARY_OP(uint8_t, bdiv_u8, x / y);
|
||||
BINARY_OP(uint32_t, bdiv_u32, x / y);
|
||||
BINARY_OP(float, bmul_f32, x * y)
|
||||
BINARY_OP(double, bmul_f64, x * y);
|
||||
BINARY_OP(uint8_t, bmul_u8, x * y);
|
||||
BINARY_OP(uint32_t, bmul_u32, x * y);
|
||||
BINARY_OP(float, bsub_f32, x - y)
|
||||
BINARY_OP(double, bsub_f64, x - y);
|
||||
BINARY_OP(uint8_t, bsub_u8, x - y);
|
||||
BINARY_OP(uint32_t, bsub_u32, x - y);
|
||||
|
@ -27,10 +27,12 @@ extern "C" __global__ void FN_NAME( \
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
|
||||
|
||||
CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8)
|
||||
CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
|
||||
// CAST_OP(__nv_bfloat16, __half, cast_bf16_f16)
|
||||
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
|
||||
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
|
||||
CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16)
|
||||
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
|
||||
// CAST_OP(__half, __nv_bfloat16, cast_f16_bf16)
|
||||
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
|
||||
@ -40,22 +42,32 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
CAST_OP(__half, __half, cast_f16_f16)
|
||||
|
||||
// CAST_OP(__half, uint8_t, cast_f16_u8 )
|
||||
CAST_OP(__half, uint32_t, cast_f16_u32)
|
||||
CAST_OP(__half, float, cast_f16_f32)
|
||||
CAST_OP(__half, double, cast_f16_f64)
|
||||
CAST_OP(uint8_t, __half, cast_u8_f16 )
|
||||
CAST_OP(uint32_t, __half, cast_u32_f16)
|
||||
CAST_OP(float, __half, cast_f32_f16)
|
||||
CAST_OP(double, __half, cast_f64_f16)
|
||||
#endif
|
||||
|
||||
CAST_OP(uint32_t, uint32_t, cast_u32_u32)
|
||||
CAST_OP(uint32_t, uint8_t, cast_u32_u8 )
|
||||
CAST_OP(uint32_t, float, cast_u32_f32)
|
||||
CAST_OP(uint32_t, double, cast_u32_f64)
|
||||
|
||||
CAST_OP(uint8_t, uint32_t, cast_u8_u32)
|
||||
CAST_OP(uint8_t, uint8_t, cast_u8_u8 )
|
||||
CAST_OP(uint8_t, float, cast_u8_f32)
|
||||
CAST_OP(uint8_t, double, cast_u8_f64)
|
||||
|
||||
CAST_OP(float, uint8_t, cast_f32_u8 )
|
||||
CAST_OP(float, uint32_t, cast_f32_u32)
|
||||
CAST_OP(float, float, cast_f32_f32)
|
||||
CAST_OP(float, double, cast_f32_f64)
|
||||
|
||||
CAST_OP(double, uint8_t, cast_f64_u8 )
|
||||
CAST_OP(double, uint32_t, cast_f64_u32)
|
||||
CAST_OP(double, float, cast_f64_f32)
|
||||
CAST_OP(double, double, cast_f64_f64)
|
||||
|
@ -39,4 +39,5 @@ EMB_OP(__half, emb_f16)
|
||||
|
||||
EMB_OP(float, emb_f32)
|
||||
EMB_OP(double, emb_f64)
|
||||
EMB_OP(uint8_t, emb_u8)
|
||||
EMB_OP(uint32_t, emb_u32)
|
||||
|
@ -42,4 +42,5 @@ WHERE_OP(__half, where_f16)
|
||||
|
||||
WHERE_OP(float, where_f32)
|
||||
WHERE_OP(double, where_f64)
|
||||
WHERE_OP(uint8_t, where_u8)
|
||||
WHERE_OP(uint32_t, where_u32)
|
||||
|
Reference in New Issue
Block a user