Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Creating Gelu op (no backward).
@@ -17,11 +17,22 @@ extern "C" __global__ void FN_NAME( \
    } \
} \

template<typename T>
__device__ T gelu_fwd(T x) {
    constexpr T fastCoeff = 0.044715;
    T x_sq = x * x;
    T x_cube = x_sq * x;
    T alpha = x + fastCoeff * x_cube;
    return 0.5 * x * (1.0 + tanhg(M_2_SQRTPI * M_SQRT1_2 * alpha));
}

#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
// UNARY_OP(__half, gelu_f16, gelu_fwd(x))
#endif

UNARY_OP(float, ucopy_f32, x)
@@ -32,3 +43,4 @@ UNARY_OP(float, usqr_f32, x*x)
UNARY_OP(float, usqr_f64, x*x)
UNARY_OP(float, usqrt_f32, sqrtg(x))
UNARY_OP(float, usqrt_f64, sqrtg(x))
UNARY_OP(float, gelu_f32, gelu_fwd(x))
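For reference, M_2_SQRTPI is 2/√π and M_SQRT1_2 is 1/√2, so their product is √(2/π) and gelu_fwd computes the usual tanh approximation 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))). A minimal host-side sketch of the same formula, not part of the commit, that can be used to spot-check kernel output:

// Reference-only sketch (not from the commit): the same tanh-based GELU
// approximation evaluated on the host.
fn gelu_ref(x: f32) -> f32 {
    // sqrt(2/pi) == M_2_SQRTPI * M_SQRT1_2
    let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (sqrt_2_over_pi * (x + 0.044715 * x * x * x)).tanh())
}

fn main() {
    for x in [-2.0f32, -1.0, 0.0, 1.0, 2.0] {
        println!("gelu({x}) = {}", gelu_ref(x));
    }
}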
src/op.rs (+28)
@@ -22,6 +22,7 @@ pub(crate) enum Op {
    Sqrt(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
    Gelu(Tensor),
    // TODO: Support for custom ops.
}
@@ -52,6 +53,7 @@ pub(crate) struct Sub;
pub(crate) struct Neg;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;

impl BinaryOp for Add {
    const NAME: &'static str = "add";
@@ -136,3 +138,29 @@ impl UnaryOp for Sqrt {
    const KERNEL_F32: &'static str = "usqrt_f32";
    const KERNEL_F64: &'static str = "usqrt_f64";
}

/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f32(v: f32) -> f32 {
    0.5 * (v)
        * (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f64(v: f64) -> f64 {
    0.5 * (v)
        * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
impl UnaryOp for Gelu {
    const NAME: &'static str = "gelu";
    fn f32(v1: f32) -> f32 {
        gelu_f32(v1)
    }
    fn f64(v1: f64) -> f64 {
        gelu_f64(v1)
    }
    const KERNEL_F32: &'static str = "gelu_f32";
    const KERNEL_F64: &'static str = "gelu_f64";
}
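The impl above suggests the shape of the crate's UnaryOp trait; a rough reconstruction, inferred from the constants and methods used here and not part of this diff (the actual definition in src/op.rs may differ):

// Assumed shape of the UnaryOp trait, inferred from the Gelu impl above;
// the real definition in src/op.rs may name or order things differently.
pub(crate) trait UnaryOp {
    const NAME: &'static str;
    const KERNEL_F32: &'static str;
    const KERNEL_F64: &'static str;
    fn f32(v1: f32) -> f32;
    fn f64(v1: f64) -> f64;
}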
@@ -240,6 +240,7 @@ impl Tensor {
    unary_op!(neg, Neg);
    unary_op!(sqr, Sqr);
    unary_op!(sqrt, Sqrt);
    unary_op!(gelu, Gelu);
    pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
        if self.rank() != 0 {
            return Err(Error::UnexpectedNumberOfDims {
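With unary_op!(gelu, Gelu) in place, the new op is exposed as a Tensor method. A hypothetical usage sketch; the crate name, Tensor::new, and Device::Cpu are assumptions about the surrounding API and are not shown in this diff:

// Hypothetical usage sketch; constructor, device, and crate names are assumptions.
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let t = Tensor::new(&[-1.0f32, 0.0, 1.0, 2.0], &Device::Cpu)?;
    let _y = t.gelu()?; // forward only: backward for gelu is not implemented in this commit
    Ok(())
}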
@@ -766,6 +767,7 @@ impl Tensor {
            | Op::Transpose(node, _, _)
            | Op::Sqr(node)
            | Op::Sqrt(node)
            | Op::Gelu(node)
            | Op::Neg(node) => {
                let (tg, nodes) = walk(node, nodes, already_seen);
                track_grad |= tg;
@@ -854,6 +856,7 @@ impl Tensor {
                    *sum_grad = sum_grad.add(&arg_grad)?
                }
                Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
                Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
                Op::Sqr(arg) => {
                    let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                    let sum_grad = grads.or_insert(arg)?;
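Backward for gelu is left unimplemented in this commit. For a future backward pass, the derivative of the tanh approximation is straightforward; a sketch only, not part of the diff:

// Sketch only (not in this commit): analytic derivative of the
// tanh-approximated GELU, g(x) = 0.5 * x * (1 + tanh(u)) with
// u = k * (x + c * x^3), k = sqrt(2/pi), c = 0.044715, giving
// g'(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * k * (1 + 3 * c * x^2).
fn gelu_grad_f32(x: f32) -> f32 {
    let c = 0.044715f32;
    let k = (2.0f32 / std::f32::consts::PI).sqrt();
    let u = k * (x + c * x * x * x);
    let t = u.tanh();
    0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * k * (1.0 + 3.0 * c * x * x)
}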