Merge pull request #1496 from bayedieng/unary

Implement urecip op for metal backend
This commit is contained in:
Nicolas Patry
2023-12-29 12:20:52 +01:00
committed by GitHub
3 changed files with 8 additions and 3 deletions

View File

@ -652,6 +652,7 @@ impl BackendStorage for MetalStorage {
("uceil", DType::F32) => contiguous::ceil::FLOAT, ("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT, ("utanh", DType::F32) => contiguous::tanh::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF, ("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF, ("usin", DType::F16) => contiguous::sin::HALF,
@ -666,6 +667,7 @@ impl BackendStorage for MetalStorage {
("uceil", DType::F16) => contiguous::ceil::HALF, ("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF, ("uround", DType::F16) => contiguous::round::HALF,
("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF, ("utanh", DType::F16) => contiguous::tanh::HALF,
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")

View File

@ -165,7 +165,7 @@ macro_rules! ops{
} }
pub mod unary { pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, recip);
} }
pub mod binary { pub mod binary {
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);

View File

@ -19,7 +19,9 @@ METAL_FUNC uint get_strided_index(
} }
template <typename T> METAL_FUNC T sqr(T in){ return in * in; } template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
// Reciprocal helper: divides the literal 1.0 by `in` and converts the quotient
// back to T. No special handling for in == 0 is done here.
template <typename T> METAL_FUNC T recip(T in){ return T(1.0 / in); }
template <typename T> METAL_FUNC T neg(T in){ return -in; } template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T erf(T in){ template <typename T> METAL_FUNC T erf(T in){
float x = (float) in; float x = (float) in;
// constants // constants
@ -57,8 +59,6 @@ template <typename T> METAL_FUNC T gelu(T x) {
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
} }
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \ kernel void FN_NAME( \
constant size_t &dim, \ constant size_t &dim, \
@ -108,6 +108,8 @@ UNARY_OP(round)
UNARY_OP(gelu_erf) UNARY_OP(gelu_erf)
UNARY_OP(erf) UNARY_OP(erf)
UNARY_OP(tanh) UNARY_OP(tanh)
UNARY_OP(recip)
UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@ -128,6 +130,7 @@ BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided) UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif #endif