Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Wrapping code to call the custom op. (#225)

* Wrapping code to call the custom op.
* Get the rms example to work.
* Get around rustfmt failing in the CI.
* Fix the rms computation.
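For reference, the "rms computation" being fixed is plain RMS normalization without a learned scale (the weight multiplication is dropped by this commit). Assuming the usual rsqrt(mean-of-squares + epsilon) scale that the kernel keeps in s_variance, each token of hidden size H is transformed as:

\[
  \mathrm{out}_i = x_i \cdot s, \qquad
  s = \frac{1}{\sqrt{\tfrac{1}{H} \sum_{j=1}^{H} x_j^2 + \epsilon}},
\]

with \(\epsilon = 10^{-5}\) as passed from the Rust launch code below.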
@@ -2,6 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
 use candle_kernels as kernels;
+pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
     CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
@@ -101,7 +102,7 @@ impl std::ops::Deref for CudaDevice {
     }
 }

-trait WrapErr<O> {
+pub trait WrapErr<O> {
     fn w(self) -> std::result::Result<O, crate::Error>;
 }

@@ -171,7 +172,7 @@ impl CudaDevice {
         })
     }

-    fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
+    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
        if !self.has_func(module_name, module_name) {
            // Leaking the string here is a bit sad but we need a &'static str and this is only
            // done once per kernel name.
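Both visibility changes in this backend (pub trait WrapErr, pub fn get_or_load_func, plus the pub use cudarc re-export above) exist so that code outside candle-core can allocate, load, and launch its own PTX kernels. A minimal sketch of that pattern, restricted to calls that appear elsewhere in this diff; MY_PTX and "my_kernel" are placeholders, not real candle or example identifiers:

use candle::cuda_backend::{cudarc, CudaDevice, WrapErr};
use cudarc::driver::{LaunchAsync, LaunchConfig};

// Placeholder PTX source; in the example below this comes from a build-script
// generated constant such as cuda_kernels::LAYERNORM_KERNELS.
const MY_PTX: &str = "";

fn launch_my_kernel(dev: &CudaDevice, len: usize) -> candle::Result<cudarc::driver::CudaSlice<f32>> {
    // Allocate the output buffer; `.w()` (WrapErr) converts the cudarc error into a candle one.
    let dst = unsafe { dev.alloc::<f32>(len) }.w()?;
    // Load the PTX module on first use and fetch the kernel handle (now public on CudaDevice).
    let func = dev.get_or_load_func("my_kernel", MY_PTX)?;
    let cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (len as u32, 1, 1), // fine for small buffers; real code should cap this
        shared_mem_bytes: 0,
    };
    unsafe { func.launch(cfg, (&dst, len as u32)) }.w()?;
    Ok(dst)
}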
@@ -96,6 +96,7 @@ impl KernelDirectories {
                 .file_stem()
                 .context("empty stem")?
                 .to_string_lossy();
+            file.write_all(b"#[rustfmt::skip]\n")?;
             let const_definition = format!(
                 r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
                 name.to_uppercase().replace('.', "_"),
@@ -1 +1,2 @@
+#[rustfmt::skip]
 pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
@@ -1,12 +1,12 @@
+#include <stdint.h>
 #include "reduction_utils.cuh"

 template <typename scalar_t>
 __device__ void
 rms_norm_kernel(scalar_t *__restrict__ out,         // [num_tokens, hidden_size]
                 const scalar_t *__restrict__ input, // [num_tokens, hidden_size]
-                const scalar_t *__restrict__ weight, // [hidden_size]
-                const float epsilon, const int num_tokens,
-                const int hidden_size) {
+                const float epsilon, const uint32_t num_tokens,
+                const uint32_t hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;

@@ -22,16 +22,14 @@ rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)input[blockIdx.x * hidden_size + idx];
-    out[blockIdx.x * hidden_size + idx] =
-        ((scalar_t)(x * s_variance)) * weight[idx];
+    out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance));
   }
 }
-extern "C" __global__ void rms_norm_kernel_f32(
+extern "C" __global__ void rms_f32(
     float *__restrict__ out,         // [num_tokens, hidden_size]
     const float *__restrict__ input, // [num_tokens, hidden_size]
-    const float *__restrict__ weight, // [hidden_size]
-    const float epsilon, const int num_tokens,
-    const int hidden_size) {
-  rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
+    const float epsilon, const uint32_t num_tokens,
+    const uint32_t hidden_size) {
+  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
 }

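One practical note on the launch shape used further down: the kernel's for-loop strides by blockDim.x, so a block does not have to cover the whole hidden dimension. A hedged sketch of capping the block size (the 1024 figure is the common CUDA per-block thread limit, not something this diff checks):

// The strided loop in rms_norm_kernel handles hidden_size > block_dim.x,
// so the block size only needs to respect the device limit.
fn block_dim_for(hidden_size: u32) -> (u32, u32, u32) {
    (hidden_size.min(1024), 1, 1)
}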
@@ -4,6 +4,8 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

+mod cuda_kernels;
+
 use clap::Parser;

 use candle::backend::BackendStorage;
@@ -40,17 +42,30 @@ impl CustomOp1 for LayerNorm {
         s: &candle::CudaStorage,
         l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
-        let device = s.device().clone();
+        use candle::cuda_backend::{cudarc, WrapErr};
+        use cudarc::driver::{LaunchAsync, LaunchConfig};
+        let (d1, d2) = l.shape().dims2()?;
+        let d1 = d1 as u32;
+        let d2 = d2 as u32;
+        let dev = s.device().clone();
         let s = s.as_cuda_slice::<f32>()?;
         let s = match l.contiguous_offsets() {
             None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-            Some((o1, o2)) => s, // TODO: slice with o1 and o2
+            Some((o1, o2)) => s.slice(o1..o2),
         };
-        let s: std::result::Result<_, candle::cuda_backend::CudaError> =
-            s.try_clone().map_err(|v| v.into());
-        let s = s?;
-        let s = candle::CudaStorage::wrap_cuda_slice(s, device);
-        Ok((s, l.shape().clone()))
+        let elem_count = l.shape().elem_count();
+        let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
+        let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
+        let params = (&dst, &s, 1e-5f32, d1, d2);
+        let cfg = LaunchConfig {
+            grid_dim: (d1, 1, 1),
+            block_dim: (d2, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        unsafe { func.launch(cfg, params) }.w()?;
+
+        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+        Ok((dst, l.shape().clone()))
     }
 }

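As a quick way to sanity-check the kernel output, here is a minimal CPU reference of the same computation in plain Rust (no candle types; it assumes the rsqrt(mean-of-squares + epsilon) scaling described above and the fixed epsilon = 1e-5 used in the launch):

// Reference RMS normalization matching rms_f32: one row per token,
// each element scaled by 1 / sqrt(mean of squares + epsilon), no weight.
fn rms_norm_ref(input: &[f32], num_tokens: usize, hidden_size: usize, epsilon: f32) -> Vec<f32> {
    let mut out = vec![0f32; input.len()];
    for t in 0..num_tokens {
        let row = &input[t * hidden_size..(t + 1) * hidden_size];
        let mean_sq = row.iter().map(|x| x * x).sum::<f32>() / hidden_size as f32;
        let scale = (mean_sq + epsilon).sqrt().recip();
        for (o, x) in out[t * hidden_size..(t + 1) * hidden_size].iter_mut().zip(row) {
            *o = x * scale;
        }
    }
    out
}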