Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 11:56:45 +00:00)

Add the erf function. (#917)
@@ -442,6 +442,7 @@ impl Tensor {
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+                Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
                 Op::Unary(_, UnaryOp::GeluErf) => {
                     Err(Error::BackwardNotSupported { op: "gelu-erf" })?
                 }
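
Note: this commit deliberately leaves erf's backward pass unsupported, mirroring how gelu is handled. Since d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2), any future backward rule would simply scale the incoming gradient by that factor. A standalone check of that identity via a central finite difference; this is only a sketch, not part of the commit, and it assumes the libm crate as a reference erf:

    // Verify d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2) numerically.
    // Assumes the libm crate (libm::erf) is available as a dependency.
    fn main() {
        let x = 0.7_f64;
        let h = 1e-6;
        // Central finite difference of erf around x.
        let numeric = (libm::erf(x + h) - libm::erf(x - h)) / (2.0 * h);
        // Closed-form derivative of erf.
        let analytic = 2.0 / std::f64::consts::PI.sqrt() * (-x * x).exp();
        println!("numeric = {numeric:.6}, analytic = {analytic:.6}"); // agree to ~6 decimals
    }
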
@@ -59,6 +59,7 @@ pub enum UnaryOp {
     Sqrt,
     Gelu,
     GeluErf,
+    Erf,
     Relu,
     Tanh,
 }
@@ -327,6 +328,7 @@ pub(crate) struct Sqr;
 pub(crate) struct Sqrt;
 pub(crate) struct Gelu;
 pub(crate) struct GeluErf;
+pub(crate) struct Erf;
 pub(crate) struct Relu;
 pub(crate) struct Tanh;
 
@@ -623,6 +625,40 @@ impl UnaryOpT for Gelu {
     }
 }
 
+impl UnaryOpT for Erf {
+    const NAME: &'static str = "erf";
+    const KERNEL: &'static str = "uerf";
+    const V: Self = Erf;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        bf16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        f16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        Self::f64(v as f64) as f32
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        crate::cpu::erf::erf(v)
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
+}
+
 impl UnaryOpT for GeluErf {
     const NAME: &'static str = "gelu_erf";
     const KERNEL: &'static str = "ugelu_erf";
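
The CPU path above delegates to crate::cpu::erf::erf, whose body is not shown in this diff. For reference, here is a self-contained sketch of one classic approximation (Abramowitz & Stegun 7.1.26, maximum absolute error around 1.5e-7); it is only an illustration and not necessarily the algorithm candle's cpu::erf module actually uses:

    // Illustrative erf approximation (Abramowitz & Stegun 7.1.26).
    // Not candle's implementation; shown only for context.
    fn erf_approx(x: f64) -> f64 {
        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();
        let t = 1.0 / (1.0 + 0.3275911 * x);
        // Degree-5 polynomial in t, evaluated in Horner form.
        let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
            + 0.254829592)
            * t;
        sign * (1.0 - poly * (-x * x).exp())
    }

    fn main() {
        // erf(1) rounds to 0.8427, matching the value asserted in the new test below.
        println!("{:.4}", erf_approx(1.0));
    }
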
@@ -490,6 +490,7 @@ impl Tensor {
     unary_op!(sqrt, Sqrt);
     unary_op!(gelu, Gelu);
     unary_op!(gelu_erf, GeluErf);
+    unary_op!(erf, Erf);
     unary_op!(relu, Relu);
 
     /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
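
With unary_op!(erf, Erf) registered, erf is exposed on Tensor just like sqrt, gelu and relu. A hedged usage sketch, assuming the candle_core crate and a CPU device:

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let t = Tensor::new(&[-1.0f32, 0.0, 0.5, 2.0], &Device::Cpu)?;
        println!("{}", t.erf()?);      // elementwise error function
        println!("{}", t.gelu_erf()?); // the erf-based gelu already present above
        Ok(())
    }
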
@@ -61,6 +61,13 @@ fn unary_op(device: &Device) -> Result<()> {
             [2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
         ]
     );
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.erf()?, 4)?,
+        [
+            [-1.0, 0.8427, 1.0, -0.1125, 0.5205],
+            [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
+        ]
+    );
     Ok(())
 }
 
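
The asserted values match the standard error function (erf(1) ≈ 0.8427, erf(0.5) ≈ 0.5205, erf(-0.1) ≈ -0.1125, and |x| ≥ 3 saturating to ±1 at four decimals); the first expected row is consistent with inputs of roughly -3, 1, 4, -0.1 and 0.5. A hedged stand-alone variant that reproduces those rounded values outside the test harness, assuming candle_core and hand-rolling the rounding instead of using test_utils::to_vec2_round:

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Inputs chosen here to be consistent with the first expected row above.
        let t = Tensor::new(&[[-3f32, 1., 4., -0.1, 0.5]], &Device::Cpu)?;
        let rounded: Vec<Vec<f32>> = t
            .erf()?
            .to_vec2::<f32>()?
            .into_iter()
            .map(|row| row.into_iter().map(|v| (v * 1e4).round() / 1e4).collect())
            .collect();
        println!("{rounded:?}"); // expected: [[-1.0, 0.8427, 1.0, -0.1125, 0.5205]]
        Ok(())
    }
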