Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 18:28:24 +00:00
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
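For reference, the dedicated op computes a standard layer norm per row: subtract the row mean, divide by sqrt(variance + eps), then scale by alpha and shift by beta. A minimal CPU sketch of these semantics (plain Rust over f32 slices, written for this note rather than taken from the candle API):

    // Reference row-wise layer norm: y = (x - mean) / sqrt(var + eps) * alpha + beta.
    fn layer_norm_row(x: &[f32], alpha: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
        let n = x.len() as f32;
        let mean = x.iter().sum::<f32>() / n;
        // Population variance in the same E[x^2] - mean^2 form the kernels use.
        let var = x.iter().map(|v| v * v).sum::<f32>() / n - mean * mean;
        let inv_norm = 1.0 / (var + eps).sqrt();
        // Normalize, then apply the per-element affine parameters.
        x.iter()
            .zip(alpha.iter().zip(beta.iter()))
            .map(|(v, (a, b))| (v - mean) * inv_norm * a + b)
            .collect()
    }

The Metal path below reproduces this row-wise computation with one threadgroup per row.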
@@ -739,6 +739,69 @@ pub fn call_rms_norm(
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_layer_norm(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    elements_to_sum: usize,
    eps: f32,
    input: &Buffer,
    input_offset: usize,
    alpha: &Buffer,
    alpha_offset: usize,
    beta: &Buffer,
    beta_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            length,
            elements_to_sum,
            (input, input_offset),
            output,
            (alpha, alpha_offset),
            (beta, beta_offset),
            eps
        )
    );

    let out_length = length / elements_to_sum;

    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };

    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        elements_to_sum as u64,
    )
    .next_power_of_two();

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}
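One threadgroup handles one output row (out_length = length / elements_to_sum), and the group width is the next power of two of min(max threads per threadgroup, elements_to_sum). The threadgroup memory request is width * 8 bytes because layernorm keeps two f32 partials per thread (a running sum and a running sum of squares), whereas call_rms_norm above keeps only one and asks for width * 4. A worked example of the dispatch arithmetic, assuming a hypothetical 8 x 1024 f32 input and a pipeline limit of 1024 threads per threadgroup:

    fn main() {
        // Hypothetical sizes, chosen for illustration only.
        let length = 8 * 1024usize;      // total number of elements
        let elements_to_sum = 1024usize; // one row reduced per threadgroup
        let max_threads = 1024u64;       // assumed pipeline.max_total_threads_per_threadgroup()

        let out_length = length / elements_to_sum;                               // 8 threadgroups, one per row
        let width = max_threads.min(elements_to_sum as u64).next_power_of_two(); // 1024 threads per group
        let shared_bytes = (width * 8).max(32);                                  // 8192 bytes: sum + sum-of-squares per thread
        assert_eq!((out_length, width, shared_bytes), (8, 1024, 8192));
    }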
@@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
    }
}

template<typename T>
METAL_FUNC void layernorm(
    constant size_t & src_numel,
    constant size_t & el_to_sum_per_block,
    device const T * src,
    device T * dst,
    device const T * alpha,
    device const T * beta,
    constant float & eps,
    uint id,
    uint tid,
    uint dst_id,
    uint block_dim,
    threadgroup float * shared_memory
) {
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    float tmp1 = 0;
    float tmp2 = 0;
    while (idx < stop_idx) {
        tmp1 += float(src[idx]);
        tmp2 += float(src[idx]) * float(src[idx]);
        idx += block_dim;
    }
    shared_memory[tid] = tmp1;
    shared_memory[tid + block_dim] = tmp2;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = block_dim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
            shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    /* wait for shared_memory[0] to be filled */
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float mean = shared_memory[0] / float(el_to_sum_per_block);
    float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
    float inv_norm = 1.0f / sqrt(var + eps);
    idx = start_idx + tid;
    while (idx < stop_idx) {
        float val = (float(src[idx]) - mean) * inv_norm;
        if (alpha != nullptr) {
            val *= float(alpha[idx - start_idx]);
        }
        if (beta != nullptr) {
            val += float(beta[idx - start_idx]);
        }
        dst[idx] = T(val);
        idx += block_dim;
    }
}

#define RMSNORM(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
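The layernorm function above mirrors rmsnorm but accumulates two partials per thread: the running sum in the first block_dim slots of shared memory and the running sum of squares in the second block_dim slots. Both halves are then folded with the usual power-of-two tree reduction before mean and var = E[x^2] - mean^2 are computed in a single pass. A host-side sketch of that reduction step, in plain Rust for illustration only:

    // Mirrors the `for (uint s = block_dim / 2; s > 0; s >>= 1)` loop in the kernel,
    // with the inner loop standing in for the threads that satisfy `tid < s`.
    // Assumes `shared.len()` is a power of two, as the host-side dispatch guarantees.
    fn tree_reduce(shared: &mut [f32]) -> f32 {
        let mut s = shared.len() / 2;
        while s > 0 {
            for tid in 0..s {
                shared[tid] += shared[tid + s]; // each step halves the active range
            }
            s >>= 1;
        }
        shared[0] // the total, read by every thread after the final barrier
    }

With both halves reduced, mean = sum / el_to_sum_per_block and var = sum_sq / el_to_sum_per_block - mean * mean, which is exactly how the kernel derives inv_norm.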
@@ -371,6 +430,25 @@ kernel void NAME( \
    rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \

#define LAYERNORM(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    device const T *alpha, \
    device const T *beta, \
    constant float &eps, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = 0; \
    layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
} \

template<typename T>
METAL_FUNC void ropei(
    constant size_t &bh,
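Each LAYERNORM(NAME, T) line in the hunks below instantiates the templated body as a concrete entry point, and the host passes the matching name as kernel_name to call_layer_norm. A hedged sketch of that name selection; the local Dtype enum is purely illustrative (not candle's DType), only the kernel names come from this diff:

    // Illustration only: map an element type to the entry point instantiated below.
    enum Dtype { F32, F16, BF16 }

    fn layer_norm_kernel_name(dtype: Dtype) -> &'static str {
        match dtype {
            Dtype::F32 => "layernorm_f32",
            Dtype::F16 => "layernorm_f16",
            Dtype::BF16 => "layernorm_bf16", // instantiated in the bfloat block closed by the #endif below
        }
    }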
@@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
LAYERNORM(layernorm_f32, float)
LAYERNORM(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
@@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
LAYERNORM(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif