mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Add a couple cuda kernels from dfdx.
This commit is contained in:
21
kernels/src/binary_mul.cu
Normal file
21
kernels/src/binary_mul.cu
Normal file
@ -0,0 +1,21 @@
|
||||
#include "binary_op_macros.cuh"
|
||||
|
||||
struct BinaryMulKernalOp {};
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
BINARY_OP(__half, bmul_fwd_f16, bmul_bwd_lhs_f16, bmul_bwd_rhs_f16, BinaryMulKernalOp,
|
||||
x * y,
|
||||
y,
|
||||
x)
|
||||
#endif
|
||||
|
||||
BINARY_OP(float, bmul_fwd_f32, bmul_bwd_lhs_f32, bmul_bwd_rhs_f32, BinaryMulKernalOp,
|
||||
x * y,
|
||||
y,
|
||||
x)
|
||||
|
||||
BINARY_OP(double, bmul_fwd_f64, bmul_bwd_lhs_f64, bmul_bwd_rhs_f64, BinaryMulKernalOp,
|
||||
x * y,
|
||||
y,
|
||||
x)
|
||||
|
Reference in New Issue
Block a user