mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op. * Get the rms example to work. * Get around rustfmt failing in the CI. * Fix the rms computation.
This commit is contained in:
@ -1,12 +1,12 @@
|
||||
#include <stdint.h>
|
||||
#include "reduction_utils.cuh"
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ void
|
||||
rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
|
||||
const scalar_t *__restrict__ input, // [num_tokens, hidden_size]
|
||||
const scalar_t *__restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
const float epsilon, const uint32_t num_tokens,
|
||||
const uint32_t hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
@ -22,16 +22,14 @@ rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance));
|
||||
}
|
||||
}
|
||||
extern "C" __global__ void rms_norm_kernel_f32(
|
||||
extern "C" __global__ void rms_f32(
|
||||
float *__restrict__ out, // [num_tokens, hidden_size]
|
||||
const float *__restrict__ input, // [num_tokens, hidden_size]
|
||||
const float *__restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
|
||||
const float epsilon, const uint32_t num_tokens,
|
||||
const uint32_t hidden_size) {
|
||||
rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user