This commit is contained in:
Nicolas Patry
2023-11-11 18:22:16 +01:00
parent b58b247323
commit 6071797450
3 changed files with 36 additions and 13 deletions

View File

@ -346,9 +346,8 @@ impl BackendStorage for MetalStorage {
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
// TODO erf does not exist in metal
("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"),
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
@ -360,9 +359,8 @@ impl BackendStorage for MetalStorage {
("uexp", DType::F16) => contiguous::exp::HALF,
("ulog", DType::F16) => contiguous::log::HALF,
("ugelu", DType::F16) => contiguous::gelu::HALF,
// TODO erf does not exist in metal
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
@ -389,9 +387,8 @@ impl BackendStorage for MetalStorage {
("uexp", DType::F32) => strided::exp::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT,
// TODO erf does not exist in metal
("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"),
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
@ -403,9 +400,8 @@ impl BackendStorage for MetalStorage {
("uexp", DType::F16) => strided::exp::HALF,
("ulog", DType::F16) => strided::log::HALF,
("ugelu", DType::F16) => strided::gelu::HALF,
// TODO erf does not exist in metal
("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("uround", DType::F16) => strided::round::HALF,

View File

@ -152,7 +152,7 @@ macro_rules! ops{
}
pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round);
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
}
pub mod binary {
ops!(add, sub, mul, div);

View File

@ -20,7 +20,30 @@ METAL_FUNC uint get_strided_index(
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
float a1 = 0.254829592;
float a2 = -0.284496736;
float a3 = 1.421413741;
float a4 = -1.453152027;
float a5 = 1.061405429;
float p = 0.3275911;
// Save the sign of x
int sign = 1;
if (x < 0)
sign = -1;
x = fabs(x);
// A&S formula 7.1.26
float t = 1.0/(1.0 + p*x);
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
return (T) sign*y;
}
template <typename T> METAL_FUNC T id(T in){ return in; }
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; }
template <typename T> METAL_FUNC T gelu(T x){
T x_sq = x * x;
T x_cube = x_sq * x;
@ -77,6 +100,8 @@ UNARY_OP(gelu)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@ -94,6 +119,8 @@ BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif