mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Add erf.
This commit is contained in:
@ -346,9 +346,8 @@ impl BackendStorage for MetalStorage {
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
// TODO erf does not exist in metal
|
||||
("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"),
|
||||
("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"),
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
@ -360,9 +359,8 @@ impl BackendStorage for MetalStorage {
|
||||
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||
("ulog", DType::F16) => contiguous::log::HALF,
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
// TODO erf does not exist in metal
|
||||
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
|
||||
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
@ -389,9 +387,8 @@ impl BackendStorage for MetalStorage {
|
||||
("uexp", DType::F32) => strided::exp::FLOAT,
|
||||
("ulog", DType::F32) => strided::log::FLOAT,
|
||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||
// TODO erf does not exist in metal
|
||||
("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"),
|
||||
("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"),
|
||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
@ -403,9 +400,8 @@ impl BackendStorage for MetalStorage {
|
||||
("uexp", DType::F16) => strided::exp::HALF,
|
||||
("ulog", DType::F16) => strided::log::HALF,
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
// TODO erf does not exist in metal
|
||||
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
|
||||
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
|
@ -152,7 +152,7 @@ macro_rules! ops{
|
||||
}
|
||||
|
||||
pub mod unary {
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round);
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
|
||||
}
|
||||
pub mod binary {
|
||||
ops!(add, sub, mul, div);
|
||||
|
@ -20,7 +20,30 @@ METAL_FUNC uint get_strided_index(
|
||||
|
||||
// Returns the square of `in`, computed in the element type T.
template <typename T>
METAL_FUNC T sqr(T in) {
    const T squared = in * in;
    return squared;
}
|
||||
// Returns the arithmetic negation of `in`.
template <typename T>
METAL_FUNC T neg(T in) {
    return -in;
}
|
||||
// Approximates the error function erf(in) using Abramowitz & Stegun
// formula 7.1.26: a degree-5 polynomial in t = 1 / (1 + p*|x|) multiplied
// by exp(-x^2). The evaluation is carried out in float regardless of T,
// and the signed result is converted back to T once, at the very end.
template <typename T> METAL_FUNC T erf(T in){
    float x = (float) in;

    // Coefficients of A&S formula 7.1.26.
    const float a1 =  0.254829592f;
    const float a2 = -0.284496736f;
    const float a3 =  1.421413741f;
    const float a4 = -1.453152027f;
    const float a5 =  1.061405429f;
    const float p  =  0.3275911f;

    // erf is an odd function: remember the sign of x, evaluate the
    // approximation on |x|, and reapply the sign to the result.
    const float sign = (x < 0) ? -1.0f : 1.0f;
    x = fabs(x);

    // A&S formula 7.1.26, polynomial evaluated with Horner's scheme.
    const float t = 1.0f / (1.0f + p * x);
    const float y = 1.0f - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * exp(-x * x);

    // Convert the whole product to T. The previous `(T) sign*y` cast only
    // `sign` (a cast binds tighter than `*`), leaving the multiplication in
    // mixed precision and relying on an implicit narrowing at the return.
    return T(sign * y);
}
|
||||
// Identity: hands the input value back unchanged.
template <typename T>
METAL_FUNC T id(T in) {
    return in;
}
|
||||
// erf-based GELU: x * (1 + erf(x * M_SQRT1_2_F)) / 2.
// NOTE(review): M_SQRT1_2_F is presumably Metal's float constant for
// 1/sqrt(2), which would make this the exact GELU x * Phi(x) — confirm
// against the Metal Shading Language specification.
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; }
|
||||
template <typename T> METAL_FUNC T gelu(T x){
|
||||
T x_sq = x * x;
|
||||
T x_cube = x_sq * x;
|
||||
@ -77,6 +100,8 @@ UNARY_OP(gelu)
|
||||
UNARY_OP(ceil)
|
||||
UNARY_OP(floor)
|
||||
UNARY_OP(round)
|
||||
UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY(id, float, copy_float, copy_float_strided)
|
||||
UNARY(id, half, copy_half, copy_half_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
@ -94,6 +119,8 @@ BFLOAT_UNARY_OP(gelu)
|
||||
BFLOAT_UNARY_OP(ceil)
|
||||
BFLOAT_UNARY_OP(floor)
|
||||
BFLOAT_UNARY_OP(round)
|
||||
BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
|
||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user