Wrapping code to call the custom op. (#225)

* Wrapping code to call the custom op.

* Get the rms example to work.

* Get around rustfmt failing in the CI.

* Fix the rms computation.
This commit is contained in:
Laurent Mazare
2023-07-23 12:31:17 +02:00
committed by GitHub
parent b8a10425ad
commit e449ce53a2
5 changed files with 35 additions and 19 deletions

View File

@@ -1,12 +1,12 @@
#include <stdint.h>
#include "reduction_utils.cuh"
template <typename scalar_t>
__device__ void
rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
const scalar_t *__restrict__ input, // [num_tokens, hidden_size]
const scalar_t *__restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size) {
const float epsilon, const uint32_t num_tokens,
const uint32_t hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
@@ -22,16 +22,14 @@ rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance));
}
}
extern "C" __global__ void rms_norm_kernel_f32(
extern "C" __global__ void rms_f32(
float *__restrict__ out, // [num_tokens, hidden_size]
const float *__restrict__ input, // [num_tokens, hidden_size]
const float *__restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size) {
rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
const float epsilon, const uint32_t num_tokens,
const uint32_t hidden_size) {
rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}