From e449ce53a2f3c85f23ca0f2e7d557a0d0003e0ca Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 23 Jul 2023 12:31:17 +0200
Subject: [PATCH] Wrapping code to call the custom op. (#225)

* Wrapping code to call the custom op.

* Get the rms example to work.

* Get around rustfmt failing in the CI.

* Fix the rms computation.
---
 candle-core/src/cuda_backend.rs                |  5 ++--
 candle-examples/build.rs                       |  1 +
 .../examples/custom-ops/cuda_kernels.rs        |  1 +
 .../custom-ops/kernels/layernorm_kernels.cu    | 18 +++++-------
 candle-examples/examples/custom-ops/main.rs    | 29 ++++++++++++++-----
 5 files changed, 35 insertions(+), 19 deletions(-)

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index d2cc3e41..5e362041 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -2,6 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
 use candle_kernels as kernels;
+pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
     CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
@@ -101,7 +102,7 @@ impl std::ops::Deref for CudaDevice {
     }
 }
 
-trait WrapErr<O> {
+pub trait WrapErr<O> {
     fn w(self) -> std::result::Result<O, crate::Error>;
 }
 
@@ -171,7 +172,7 @@ impl CudaDevice {
         })
     }
 
-    fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
+    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
         if !self.has_func(module_name, module_name) {
             // Leaking the string here is a bit sad but we need a &'static str and this is only
             // done once per kernel name.
diff --git a/candle-examples/build.rs b/candle-examples/build.rs
index 7f69fa77..119e5b0a 100644
--- a/candle-examples/build.rs
+++ b/candle-examples/build.rs
@@ -96,6 +96,7 @@ impl KernelDirectories {
                     .file_stem()
                     .context("empty stem")?
                     .to_string_lossy();
+                file.write_all(b"#[rustfmt::skip]\n")?;
                 let const_definition = format!(
                     r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
                     name.to_uppercase().replace('.', "_"),
diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs
index 07d18342..0bee73aa 100644
--- a/candle-examples/examples/custom-ops/cuda_kernels.rs
+++ b/candle-examples/examples/custom-ops/cuda_kernels.rs
@@ -1 +1,2 @@
+#[rustfmt::skip]
 pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
diff --git a/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
index 07ab8639..a0836392 100644
--- a/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
+++ b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
@@ -1,12 +1,12 @@
+#include <stdint.h>
 #include "reduction_utils.cuh"
 
 template <typename scalar_t>
 __device__ void
 rms_norm_kernel(scalar_t *__restrict__ out,          // [num_tokens, hidden_size]
                 const scalar_t *__restrict__ input,  // [num_tokens, hidden_size]
-                const scalar_t *__restrict__ weight, // [hidden_size]
-                const float epsilon, const int num_tokens,
-                const int hidden_size) {
+                const float epsilon, const uint32_t num_tokens,
+                const uint32_t hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
@@ -22,16 +22,14 @@ rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)input[blockIdx.x * hidden_size + idx];
-    out[blockIdx.x * hidden_size + idx] =
-        ((scalar_t)(x * s_variance)) * weight[idx];
+    out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance));
   }
 }
 
-extern "C" __global__ void rms_norm_kernel_f32(
+extern "C" __global__ void rms_f32(
     float *__restrict__ out,          // [num_tokens, hidden_size]
     const float *__restrict__ input,  // [num_tokens, hidden_size]
-    const float *__restrict__ weight, // [hidden_size]
-    const float epsilon, const int num_tokens,
-    const int hidden_size) {
-  rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
+    const float epsilon, const uint32_t num_tokens,
+    const uint32_t hidden_size) {
+  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
 }
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
index adc7abd7..9c917cca 100644
--- a/candle-examples/examples/custom-ops/main.rs
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -4,6 +4,8 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+mod cuda_kernels;
+
 use clap::Parser;
 
 use candle::backend::BackendStorage;
@@ -40,17 +42,30 @@ impl CustomOp1 for LayerNorm {
         s: &candle::CudaStorage,
         l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
-        let device = s.device().clone();
+        use candle::cuda_backend::{cudarc, WrapErr};
+        use cudarc::driver::{LaunchAsync, LaunchConfig};
+        let (d1, d2) = l.shape().dims2()?;
+        let d1 = d1 as u32;
+        let d2 = d2 as u32;
+        let dev = s.device().clone();
         let s = s.as_cuda_slice::<f32>()?;
         let s = match l.contiguous_offsets() {
             None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-            Some((o1, o2)) => s, // TODO: slice with o1 and o2
+            Some((o1, o2)) => s.slice(o1..o2),
        };
-        let s: std::result::Result<_, candle::cuda_backend::CudaError> =
-            s.try_clone().map_err(|v| v.into());
-        let s = s?;
-        let s = candle::CudaStorage::wrap_cuda_slice(s, device);
-        Ok((s, l.shape().clone()))
+        let elem_count = l.shape().elem_count();
+        let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
+        let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
+        let params = (&dst, &s, 1e-5f32, d1, d2);
+        let cfg = LaunchConfig {
+            grid_dim: (d1, 1, 1),
+            block_dim: (d2, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        unsafe { func.launch(cfg, params) }.w()?;
+
+        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+        Ok((dst, l.shape().clone()))
     }
 }
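
Note (outside the patch itself): the kernel implements RMS normalization, i.e. out = x * rsqrt(mean(x^2) + eps), with eps passed as 1e-5 from the Rust side; the x^2 reduction lives in the part of the .cu file these hunks do not touch. Below is a minimal sketch of how the wrapped op could be invoked from the example. It assumes the `LayerNorm` struct defined in main.rs and the `custom_op1` entry point on `Tensor` (renamed `apply_op1` in later candle releases); the device setup and input shape are illustrative only.

    use candle::{Device, Result, Tensor};

    fn run() -> Result<()> {
        // Requires a CUDA build: build.rs compiles layernorm_kernels.cu to PTX and
        // cuda_kernels.rs embeds it as LAYERNORM_KERNELS.
        let device = Device::new_cuda(0)?;
        // Rows map to num_tokens (one block each) and columns to hidden_size
        // (one thread each), matching the LaunchConfig built in cuda_fwd.
        let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
        // Dispatches to LayerNorm::cuda_fwd, which launches the rms_f32 kernel.
        let t = t.custom_op1(LayerNorm)?;
        println!("{t}");
        Ok(())
    }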