diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 597c2f01..72b15006 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -293,6 +293,12 @@ impl BackendStorage for MetalStorage { ("uneg", DType::F32) => contiguous::neg::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + // TODO erf does not exist in metal + ("ugelu_erf", DType::F32) => contiguous::gelu::FLOAT, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("uround", DType::F32) => contiguous::round::FLOAT, (name, dtype) => todo!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -519,7 +525,6 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let device = self.device(); let mut buffer = device.new_buffer(dst_el, dtype); - let out = self.to_cpu_storage().unwrap(); let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", (left, right) => todo!("index select metal {left:?} {right:?}"), @@ -690,6 +695,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::U32 => candle_metal_kernels::unary::strided::copy::U32, dtype => todo!("copy_strided not implemented for {dtype:?}"), }; candle_metal_kernels::call_unary_strided( diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7288216a..5a50d46f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -112,13 +112,7 @@ macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { - #[derive(Clone, Copy)] pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } $( pub mod $name { use super::Kernel; @@ -127,16 +121,17 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_float"); + pub const HALF: Kernel = Kernel("copy_half"); + pub const BFLOAT: Kernel = Kernel("copy_bfloat"); + pub const U32: Kernel = Kernel("copy_u32"); + } } pub mod strided { - #[derive(Clone, Copy)] pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } $( pub mod $name { use super::Kernel; @@ -145,12 +140,19 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_float_strided"); + pub const HALF: Kernel = Kernel("copy_half_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided"); + pub const U32: Kernel = Kernel("copy_u32_strided"); + } } }; } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, copy, log); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round); } pub mod binary { ops!(add, sub, mul, div); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index eb6424e8..ae690ca4 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -1,4 +1,7 @@ #include +#include +# +using namespace metal; METAL_FUNC uint get_strided_index( uint idx, @@ -18,9 +21,15 @@ METAL_FUNC uint get_strided_index( template METAL_FUNC T sqr(T in){ return in * in; } template METAL_FUNC T neg(T in){ return -in; } template METAL_FUNC T id(T in){ return in; } +template METAL_FUNC T gelu(T x){ + T x_sq = x * x; + T x_cube = x_sq * x; + T alpha = x + static_cast(0.044715) * x_cube; + T beta = (static_cast(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); + return static_cast(0.5) * x * (static_cast(1.0) + tanh(beta)); +} -using namespace metal; #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -64,8 +73,14 @@ UNARY_OP(sqrt) UNARY_OP(neg) UNARY_OP(exp) UNARY_OP(log) +UNARY_OP(gelu) +UNARY_OP(ceil) +UNARY_OP(floor) +UNARY_OP(round) UNARY(id, float, copy_float, copy_float_strided) UNARY(id, half, copy_half, copy_half_strided) +UNARY(id, uint8_t, copy_u8, copy_u8_strided) +UNARY(id, uint32_t, copy_u32, copy_u32_strided) #if __METAL_VERSION__ >= 310 BFLOAT_UNARY_OP(cos) @@ -75,6 +90,10 @@ BFLOAT_UNARY_OP(sqrt) BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) BFLOAT_UNARY_OP(log) +BFLOAT_UNARY_OP(gelu) +BFLOAT_UNARY_OP(ceil) +BFLOAT_UNARY_OP(floor) +BFLOAT_UNARY_OP(round) UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif