diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index b3116c86..eadbf1f1 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -346,9 +346,8 @@ impl BackendStorage for MetalStorage { ("uexp", DType::F32) => contiguous::exp::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT, ("ugelu", DType::F32) => contiguous::gelu::FLOAT, - // TODO erf does not exist in metal - ("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"), - ("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"), + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("uerf", DType::F32) => contiguous::erf::FLOAT, ("uceil", DType::F32) => contiguous::ceil::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT, @@ -360,9 +359,8 @@ impl BackendStorage for MetalStorage { ("uexp", DType::F16) => contiguous::exp::HALF, ("ulog", DType::F16) => contiguous::log::HALF, ("ugelu", DType::F16) => contiguous::gelu::HALF, - // TODO erf does not exist in metal - ("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"), - ("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"), + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("uerf", DType::F16) => contiguous::erf::HALF, ("uceil", DType::F16) => contiguous::ceil::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF, ("uround", DType::F16) => contiguous::round::HALF, @@ -389,9 +387,8 @@ impl BackendStorage for MetalStorage { ("uexp", DType::F32) => strided::exp::FLOAT, ("ulog", DType::F32) => strided::log::FLOAT, ("ugelu", DType::F32) => strided::gelu::FLOAT, - // TODO erf does not exist in metal - ("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"), - ("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"), + ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, + ("uerf", DType::F32) => strided::erf::FLOAT, ("uceil", DType::F32) => strided::ceil::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT, ("uround", DType::F32) => strided::round::FLOAT, @@ -403,9 +400,8 @@ impl BackendStorage for MetalStorage { ("uexp", DType::F16) => strided::exp::HALF, ("ulog", DType::F16) => strided::log::HALF, ("ugelu", DType::F16) => strided::gelu::HALF, - // TODO erf does not exist in metal - ("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"), - ("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"), + ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, + ("uerf", DType::F16) => strided::erf::HALF, ("uceil", DType::F16) => strided::ceil::HALF, ("ufloor", DType::F16) => strided::floor::HALF, ("uround", DType::F16) => strided::round::HALF, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 2cadc8c6..afbcbff7 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -152,7 +152,7 @@ macro_rules! ops{ } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf); } pub mod binary { ops!(add, sub, mul, div); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index ae690ca4..9f614b30 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -20,7 +20,30 @@ METAL_FUNC uint get_strided_index( template METAL_FUNC T sqr(T in){ return in * in; } template METAL_FUNC T neg(T in){ return -in; } +template METAL_FUNC T erf(T in){ + float x = (float) in; + // constants + float a1 = 0.254829592; + float a2 = -0.284496736; + float a3 = 1.421413741; + float a4 = -1.453152027; + float a5 = 1.061405429; + float p = 0.3275911; + + // Save the sign of x + int sign = 1; + if (x < 0) + sign = -1; + x = fabs(x); + + // A&S formula 7.1.26 + float t = 1.0/(1.0 + p*x); + float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x); + + return (T) sign*y; +} template METAL_FUNC T id(T in){ return in; } +template METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; } template METAL_FUNC T gelu(T x){ T x_sq = x * x; T x_cube = x_sq * x; @@ -77,6 +100,8 @@ UNARY_OP(gelu) UNARY_OP(ceil) UNARY_OP(floor) UNARY_OP(round) +UNARY_OP(gelu_erf) +UNARY_OP(erf) UNARY(id, float, copy_float, copy_float_strided) UNARY(id, half, copy_half, copy_half_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) @@ -94,6 +119,8 @@ BFLOAT_UNARY_OP(gelu) BFLOAT_UNARY_OP(ceil) BFLOAT_UNARY_OP(floor) BFLOAT_UNARY_OP(round) +BFLOAT_UNARY_OP(gelu_erf) +BFLOAT_UNARY_OP(erf) UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif