diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index f12463a4..bab44a05 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -750,6 +750,64 @@ pub fn call_last_softmax(
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_rms_norm(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    length: usize,
+    elements_to_sum: usize,
+    eps: f32,
+    input: &Buffer,
+    input_offset: usize,
+    alpha: &Buffer,
+    alpha_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            length,
+            elements_to_sum,
+            (input, input_offset),
+            output,
+            (alpha, alpha_offset),
+            eps
+        )
+    );
+
+    let out_length = length / elements_to_sum;
+
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        elements_to_sum as u64,
+    )
+    .next_power_of_two();
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 pub fn call_affine(
     device: &Device,
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 93dac662..3c3cbc14 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -260,6 +260,59 @@ kernel void NAME(
     } \
 } \
 
+#define RMSNORM(NAME, T) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    device const T *alpha, \
+    constant float &eps, \
+    \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    shared_memory[tid] = 0; \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+    size_t idx = start_idx + tid; \
+    \
+    \
+    float tmp = 0; \
+    while (idx < stop_idx) { \
+        tmp = tmp + float(src[idx]) * float(src[idx]); \
+        idx += block_dim; \
+    } \
+    shared_memory[tid] = tmp; \
+    \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s) { \
+            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
+        } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
+    } \
+    \
+    /* wait for shared_memory[0] to be filled */ \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    \
+    float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
+    float inv_norm = 1.0f / norm; \
+    idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        float val = float(src[idx]) * inv_norm; \
+        if (alpha != nullptr) { \
+            val *= float(alpha[idx - start_idx]); \
+        } \
+        dst[idx] = T(val); \
+        idx += block_dim; \
+    } \
+} \
+
 REDUCE(x + y, fast_sum_f32_strided, float, 0)
 REDUCE(x + y, fast_sum_u32_strided, uint, 0)
 REDUCE(x + y, fast_sum_f16_strided, half, 0)
@@ -286,6 +339,8 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
 
 SOFTMAX(softmax_f32, float)
 SOFTMAX(softmax_f16, half)
+RMSNORM(rmsnorm_f32, float)
+RMSNORM(rmsnorm_f16, half)
 
 #if __METAL_VERSION__ >= 220
 REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@@ -303,4 +358,5 @@ REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
 ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
 ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
 SOFTMAX(softmax_bf16, bfloat)
+RMSNORM(rmsnorm_bf16, bfloat)
 #endif
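For reference, each RMSNORM threadgroup handles one output row of el_to_sum_per_block elements: the threads stride over the row accumulating the sum of squares in f32, fold it down with a power-of-two tree reduction in threadgroup memory, then rescale the row by 1 / sqrt(mean(x^2) + eps) and the alpha weights. This is also why call_rms_norm rounds the threadgroup width up with next_power_of_two(): the s = block_dim / 2 halving loop only covers every slot when block_dim is a power of two, and the surplus threads are harmless because their shared_memory slots stay zero. A minimal scalar sketch of the per-row math in Rust (illustrative names, not part of the patch):

// Scalar reference for one row: y[i] = x[i] / sqrt(mean(x^2) + eps) * alpha[i].
fn rms_norm_row(x: &[f32], alpha: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_norm = 1.0 / (mean_sq + eps).sqrt();
    x.iter()
        .zip(alpha)
        .map(|(v, a)| v * inv_norm * a)
        .collect()
}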
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index d725bdc2..1dac8c3b 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -236,7 +236,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
             layout.start_offset() * storage.dtype().size_in_bytes(),
             &output,
         )
-        .unwrap();
+        .map_err(candle::Error::wrap)?;
         let newstorage =
             candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
         Ok((newstorage, layout.shape().clone()))
@@ -383,6 +383,51 @@ impl candle::CustomOp2 for RmsNorm {
         };
         Ok((dst, l1.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        s1: &candle::MetalStorage,
+        l1: &Layout,
+        s2: &candle::MetalStorage,
+        l2: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = s1.device();
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match (s1.dtype(), s2.dtype()) {
+            (DType::F32, DType::F32) => "rmsnorm_f32",
+            (DType::F16, DType::F16) => "rmsnorm_f16",
+            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
+            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
+        };
+
+        if !(l1.is_contiguous() && l2.is_contiguous()) {
+            candle::bail!("Non contiguous rmsnorm is not implemented");
+        }
+
+        let last_dim = l1.dims()[l1.shape().rank() - 1];
+        let elem_count = l1.shape().elem_count();
+        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
+        candle_metal_kernels::call_rms_norm(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            self.eps,
+            s1.buffer(),
+            l1.start_offset() * s1.dtype().size_in_bytes(),
+            s2.buffer(),
+            l2.start_offset() * s2.dtype().size_in_bytes(),
+            &output,
+        )
+        .map_err(candle::Error::wrap)?;
+        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
+        Ok((newstorage, l1.shape().clone()))
+    }
 }
 
 pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
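With the kernel, the call_rms_norm dispatch wrapper, and the metal_fwd hook in place, the fast path can be sanity-checked against the existing rms_norm_slow fallback from the same module. A rough validation sketch, assuming a Metal-capable machine and that the elementwise sub/abs/max_all ops used for the comparison are supported on the Metal backend (the tolerance is arbitrary):

use candle::{Device, Tensor};
use candle_nn::ops::{rms_norm, rms_norm_slow};

fn main() -> candle::Result<()> {
    let device = Device::new_metal(0)?;
    let x = Tensor::randn(0f32, 1f32, (8, 1024), &device)?;
    let alpha = Tensor::randn(0f32, 1f32, 1024, &device)?;
    // Fast path: the RmsNorm CustomOp2, which now dispatches to rmsnorm_f32 on Metal.
    let fast = rms_norm(&x, &alpha, 1e-5)?;
    // Reference path: composed from ordinary tensor ops.
    let slow = rms_norm_slow(&x, &alpha, 1e-5)?;
    // Both paths accumulate in f32, so they should agree to a tight tolerance.
    let max_diff = (&fast - &slow)?.abs()?.max_all()?.to_scalar::<f32>()?;
    assert!(max_diff < 1e-5, "metal rmsnorm diverged: {max_diff}");
    Ok(())
}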