diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 2703b1c5..9abbda6e 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -652,6 +652,7 @@ impl BackendStorage for MetalStorage { ("uceil", DType::F32) => contiguous::ceil::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT, + ("urecip", DType::F32) => contiguous::recip::FLOAT, ("utanh", DType::F32) => contiguous::tanh::FLOAT, ("ucos", DType::F16) => contiguous::cos::HALF, ("usin", DType::F16) => contiguous::sin::HALF, @@ -666,6 +667,7 @@ impl BackendStorage for MetalStorage { ("uceil", DType::F16) => contiguous::ceil::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF, ("uround", DType::F16) => contiguous::round::HALF, + ("urecip", DType::F16) => contiguous::recip::HALF, ("utanh", DType::F16) => contiguous::tanh::HALF, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index dd97a86d..94479882 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -165,7 +165,7 @@ macro_rules! ops{ } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, recip); } pub mod binary { ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 04fa37a9..826b9045 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -19,7 +19,9 @@ METAL_FUNC uint get_strided_index( } template METAL_FUNC T sqr(T in){ return in * in; } +template METAL_FUNC T recip(T in){ return T(1.0 / in); } template METAL_FUNC T neg(T in){ return -in; } + template METAL_FUNC T erf(T in){ float x = (float) in; // constants @@ -57,8 +59,6 @@ template METAL_FUNC T gelu(T x) { return static_cast(0.5) * x * (static_cast(1.0) + T(tanh(beta))); } - - #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ constant size_t &dim, \ @@ -108,6 +108,8 @@ UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) +UNARY_OP(recip) + UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) @@ -128,6 +130,7 @@ BFLOAT_UNARY_OP(round) BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) +BFLOAT_UNARY_OP(recip) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif