mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the erf function. (#917)
This commit is contained in:
@ -442,6 +442,7 @@ impl Tensor {
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
|
||||
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
|
||||
Op::Unary(_, UnaryOp::GeluErf) => {
|
||||
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
|
||||
}
|
||||
|
@ -59,6 +59,7 @@ pub enum UnaryOp {
|
||||
Sqrt,
|
||||
Gelu,
|
||||
GeluErf,
|
||||
Erf,
|
||||
Relu,
|
||||
Tanh,
|
||||
}
|
||||
@ -327,6 +328,7 @@ pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
pub(crate) struct GeluErf;
|
||||
pub(crate) struct Erf;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Tanh;
|
||||
|
||||
@ -623,6 +625,40 @@ impl UnaryOpT for Gelu {
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Erf {
|
||||
const NAME: &'static str = "erf";
|
||||
const KERNEL: &'static str = "uerf";
|
||||
const V: Self = Erf;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from_f64(Self::f64(v.to_f64()))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
Self::f64(v as f64) as f32
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
crate::cpu::erf::erf(v)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for GeluErf {
|
||||
const NAME: &'static str = "gelu_erf";
|
||||
const KERNEL: &'static str = "ugelu_erf";
|
||||
|
@ -490,6 +490,7 @@ impl Tensor {
|
||||
unary_op!(sqrt, Sqrt);
|
||||
unary_op!(gelu, Gelu);
|
||||
unary_op!(gelu_erf, GeluErf);
|
||||
unary_op!(erf, Erf);
|
||||
unary_op!(relu, Relu);
|
||||
|
||||
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
|
||||
|
Reference in New Issue
Block a user