mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
@ -70,6 +70,7 @@ impl Tensor {
|
||||
| Op::Sqr(node)
|
||||
| Op::Sqrt(node)
|
||||
| Op::Gelu(node)
|
||||
| Op::Relu(node)
|
||||
| Op::Exp(node)
|
||||
| Op::Log(node)
|
||||
| Op::Sin(node)
|
||||
@ -210,6 +211,7 @@ impl Tensor {
|
||||
}
|
||||
Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
|
||||
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
|
||||
Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
|
||||
Op::Sqr(arg) => {
|
||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -37,6 +37,7 @@ pub(crate) enum Op {
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
Gelu(Tensor),
|
||||
Relu(Tensor),
|
||||
// TODO: Support for custom ops.
|
||||
}
|
||||
|
||||
@ -81,6 +82,7 @@ pub(crate) struct Neg;
|
||||
pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
pub(crate) struct Relu;
|
||||
|
||||
macro_rules! bin_op {
|
||||
($op:ident, $name: literal, $e: expr) => {
|
||||
@ -189,9 +191,33 @@ impl UnaryOp for Gelu {
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
const KERNEL_BF16: &'static str = "gelu_bf16";
|
||||
const KERNEL_F16: &'static str = "gelu_f16";
|
||||
const KERNEL_F32: &'static str = "gelu_f32";
|
||||
const KERNEL_F64: &'static str = "gelu_f64";
|
||||
const KERNEL_U32: &'static str = "gelu_u32";
|
||||
const KERNEL_BF16: &'static str = "ugelu_bf16";
|
||||
const KERNEL_F16: &'static str = "ugelu_f16";
|
||||
const KERNEL_F32: &'static str = "ugelu_f32";
|
||||
const KERNEL_F64: &'static str = "ugelu_f64";
|
||||
const KERNEL_U32: &'static str = "ugelu_u32";
|
||||
}
|
||||
|
||||
impl UnaryOp for Relu {
|
||||
const NAME: &'static str = "relu";
|
||||
const KERNEL_BF16: &'static str = "urelu_bf16";
|
||||
const KERNEL_F16: &'static str = "urelu_f16";
|
||||
const KERNEL_F32: &'static str = "urelu_f32";
|
||||
const KERNEL_F64: &'static str = "urelu_f64";
|
||||
const KERNEL_U32: &'static str = "urelu_u32";
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.max(bf16::ZERO)
|
||||
}
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.max(f16::ZERO)
|
||||
}
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.max(0f32)
|
||||
}
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.max(0f64)
|
||||
}
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
@ -320,6 +320,7 @@ impl Tensor {
|
||||
unary_op!(sqr, Sqr);
|
||||
unary_op!(sqrt, Sqrt);
|
||||
unary_op!(gelu, Gelu);
|
||||
unary_op!(relu, Relu);
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
if self.rank() != 0 {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
|
@ -26,13 +26,19 @@ extern "C" __global__ void FN_NAME( \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
__device__ T gelu_fwd(T x) {
|
||||
__device__ __forceinline__ T gelu_fwd(T x) {
|
||||
T x_sq = x * x;
|
||||
T x_cube = x_sq * x;
|
||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T relu_fwd(T x) {
|
||||
T zero = 0.;
|
||||
return maxg(x, zero);
|
||||
}
|
||||
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
UNARY_OP(__half, ucopy_f16, x)
|
||||
@ -44,7 +50,8 @@ UNARY_OP(__half, ucos_f16, cosg(x))
|
||||
UNARY_OP(__half, uabs_f16, absg(x))
|
||||
UNARY_OP(__half, usqr_f16, x*x)
|
||||
UNARY_OP(__half, usqrt_f16, sqrtg(x))
|
||||
UNARY_OP(__half, gelu_f16, gelu_fwd(x))
|
||||
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
||||
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||
#endif
|
||||
|
||||
UNARY_OP(float, ucopy_f32, x)
|
||||
@ -65,5 +72,7 @@ UNARY_OP(float, usqr_f32, x*x)
|
||||
UNARY_OP(double, usqr_f64, x*x)
|
||||
UNARY_OP(float, usqrt_f32, sqrtg(x))
|
||||
UNARY_OP(double, usqrt_f64, sqrtg(x))
|
||||
UNARY_OP(float, gelu_f32, gelu_fwd(x))
|
||||
UNARY_OP(double, gelu_f64, gelu_fwd(x))
|
||||
UNARY_OP(float, ugelu_f32, gelu_fwd(x))
|
||||
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
|
||||
UNARY_OP(float, urelu_f32, relu_fwd(x))
|
||||
UNARY_OP(double, urelu_f64, relu_fwd(x))
|
||||
|
Reference in New Issue
Block a user