mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal. * Wrapper for the metal kernel. * Get the ops to actually work. * Fix, get the tests to pass.
This commit is contained in:
@ -750,6 +750,64 @@ pub fn call_last_softmax(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_rms_norm(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
elements_to_sum: usize,
|
||||
eps: f32,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
alpha: &Buffer,
|
||||
alpha_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
length,
|
||||
elements_to_sum,
|
||||
(input, input_offset),
|
||||
output,
|
||||
(alpha, alpha_offset),
|
||||
eps
|
||||
)
|
||||
);
|
||||
|
||||
let out_length = length / elements_to_sum;
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
elements_to_sum as u64,
|
||||
)
|
||||
.next_power_of_two();
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_affine(
|
||||
device: &Device,
|
||||
|
Reference in New Issue
Block a user