Add relu kernel for metal (#1488)

* Add relu kernel for metal

* Copy error messages proposed in #1491

* Revert non relu changes

* Fix name changes

* Fix the last of us (:

* Fix copy and paste mistakes

* Fix typo

* Revert order changes

* Revert order change

* Add deleted functions back

* Run rustfmt
This commit is contained in:
Juarez Bochi
2024-01-10 12:27:17 -05:00
committed by GitHub
parent a897fda74e
commit ae06cb74bb
3 changed files with 14 additions and 2 deletions

View File

@ -675,6 +675,7 @@ impl BackendStorage for MetalStorage {
("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
@ -691,6 +692,7 @@ impl BackendStorage for MetalStorage {
("uround", DType::F16) => contiguous::round::HALF,
("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF,
("urelu", DType::F16) => contiguous::relu::HALF,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
@ -721,6 +723,7 @@ impl BackendStorage for MetalStorage {
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
@ -735,6 +738,7 @@ impl BackendStorage for MetalStorage {
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")

View File

@ -174,8 +174,8 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
recip
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip
);
}
pub mod binary {

View File

@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) {
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
// Rectified linear unit for a single element: an input that compares less
// than zero yields zero; any other input is returned unchanged.
template <typename T> METAL_FUNC T relu(T in) {
    return (in < 0) ? static_cast<T>(0) : in;
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@ -110,6 +116,7 @@ UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif