mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Kernel build example (#224)
* Build example kernels. * Add some sample custom kernel. * Get the example kernel to compile. * Add some cuda code. * More cuda custom op. * More cuda custom ops.
This commit is contained in:
1
candle-examples/examples/custom-ops/cuda_kernels.rs
Normal file
1
candle-examples/examples/custom-ops/cuda_kernels.rs
Normal file
@ -0,0 +1 @@
|
||||
/// Contents of `layernorm_kernels.ptx`, embedded as a string at compile time via
/// `include_str!` from a path under the build output directory (`OUT_DIR`).
// NOTE(review): the path contains a doubled `/` ("kernels//layernorm_kernels.ptx");
// presumably harmless and produced by whatever generates this file — confirm there.
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
|
@ -0,0 +1,37 @@
|
||||
#include "reduction_utils.cuh"
|
||||
|
||||
// RMS-normalizes one row of `input` per thread block and scales it by `weight`.
//
// The row is selected by blockIdx.x; the block's threads walk the row's
// `hidden_size` elements starting at threadIdx.x in steps of blockDim.x.
// The function contains block-wide barriers, so every thread of the block
// must call it. `num_tokens` is accepted but not read here.
template <typename scalar_t>
__device__ void
rms_norm_kernel(scalar_t *__restrict__ out,          // [num_tokens, hidden_size]
                const scalar_t *__restrict__ input,  // [num_tokens, hidden_size]
                const scalar_t *__restrict__ weight, // [hidden_size]
                const float epsilon, const int num_tokens,
                const int hidden_size) {
  // 1 / sqrt(mean(x^2) + epsilon) for this row, written by thread 0 and read
  // by every thread after the barrier below.
  __shared__ float s_inv_rms;

  // Offset of this block's row (same unsigned arithmetic as computing
  // blockIdx.x * hidden_size + col inline).
  const unsigned int row_start = blockIdx.x * hidden_size;

  // Pass 1: each thread accumulates the squares of the elements it owns.
  float sum_sq = 0.0f;
  for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
    const float x = (float)input[row_start + col];
    sum_sq += x * x;
  }

  // Combine the per-thread partial sums across the block.
  const float total_sq = blockReduceSum<float>(sum_sq);
  if (threadIdx.x == 0) {
    s_inv_rms = rsqrtf(total_sq / hidden_size + epsilon);
  }
  // Make s_inv_rms visible to the whole block before it is used.
  __syncthreads();

  // Pass 2: normalize in float, cast back to scalar_t, then apply the weight.
  for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
    const float x = (float)input[row_start + col];
    const scalar_t normalized = (scalar_t)(x * s_inv_rms);
    out[row_start + col] = normalized * weight[col];
  }
}
|
||||
// f32 kernel entry point: forwards all arguments to the templated
// rms_norm_kernel instantiated for float. Declared extern "C" so the symbol
// keeps this exact, unmangled name.
// NOTE(review): rms_norm_kernel indexes rows by blockIdx.x, so this presumably
// expects a launch with one block per row — confirm against the launch site.
extern "C" __global__ void rms_norm_kernel_f32(
    float *__restrict__ out,          // [num_tokens, hidden_size]
    const float *__restrict__ input,  // [num_tokens, hidden_size]
    const float *__restrict__ weight, // [hidden_size]
    const float epsilon, const int num_tokens,
    const int hidden_size) {
  rms_norm_kernel<float>(out, input, weight, epsilon, num_tokens, hidden_size);
}
|
||||
|
@ -0,0 +1,46 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
// Sums `val` across the 32 lanes of a warp using XOR shuffles with offsets
// 16, 8, 4, 2, 1. The full 0xffffffff mask names all 32 lanes as participants.
// After the last step each lane holds the sum over the warp.
template <typename T> __inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int offset = 16; offset >= 1; offset /= 2) {
    // Value held by the lane whose id differs from ours in the `offset` bit.
    const T partner = __shfl_xor_sync(0xffffffff, val, offset, 32);
    val += partner;
  }
  return val;
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
// Two-stage reduction: each warp reduces its own values, lane 0 of every warp
// stores that partial sum in shared memory, and after a block-wide barrier the
// partial sums are reduced once more with a warp reduction.
// The scratch array has 32 slots, i.e. room for 32 warps. For such block sizes
// only threads of the first warp pass the slot test below, so callers should
// read the result from the first warp (this file's caller uses thread 0).
template <typename T> __inline__ __device__ T blockReduceSum(T val) {
  // One partial-sum slot per warp.
  static __shared__ T warp_sums[32];
  const int lane_id = threadIdx.x % 32;
  const int warp_id = threadIdx.x / 32;

  // Stage 1: reduce within each warp and publish the warp's total.
  val = warpReduceSum<T>(val);
  if (lane_id == 0) {
    warp_sums[warp_id] = val;
  }

  // All partial sums must be written before any of them is read.
  __syncthreads();

  // The division is done in floating point (not as an integer shift) so that
  // a trailing partial warp still counts when blockDim.x is not a multiple of
  // 32: e.g. 48 threads give 1.5, so slots 0 and 1 are both read.
  const bool has_slot = threadIdx.x < (blockDim.x / 32.f);
  val = has_slot ? warp_sums[lane_id] : (T)(0.0f);

  // Stage 2: reduce the per-warp totals.
  return warpReduceSum<T>(val);
}
|
65
candle-examples/examples/custom-ops/main.rs
Normal file
65
candle-examples/examples/custom-ops/main.rs
Normal file
@ -0,0 +1,65 @@
|
||||
#![allow(dead_code)]
|
||||
#![allow(unused)]
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cpu_backend;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
}
|
||||
|
||||
/// Stateless marker type for the example custom op; it exists only to carry
/// the `CustomOp1` implementation below.
struct LayerNorm;
|
||||
|
||||
impl CustomOp1 for LayerNorm {
|
||||
fn name(&self) -> &'static str {
|
||||
"layer-norm"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
let s = s.as_slice::<f32>()?;
|
||||
let _s = match l.contiguous_offsets() {
|
||||
None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
Some((o1, o2)) => &s[o1..o2],
|
||||
};
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
s: &candle::CudaStorage,
|
||||
l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
let device = s.device().clone();
|
||||
let s = s.as_cuda_slice::<f32>()?;
|
||||
let s = match l.contiguous_offsets() {
|
||||
None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
Some((o1, o2)) => s, // TODO: slice with o1 and o2
|
||||
};
|
||||
let s: std::result::Result<_, candle::cuda_backend::CudaError> =
|
||||
s.try_clone().map_err(|v| v.into());
|
||||
let s = s?;
|
||||
let s = candle::CudaStorage::wrap_cuda_slice(s, device);
|
||||
Ok((s, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
|
||||
println!("{t}");
|
||||
let t = t.custom_op1(LayerNorm)?;
|
||||
println!("{t}");
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user