Mirror of https://github.com/huggingface/candle.git
Kernel build example (#224)
* Build example kernels.
* Add some sample custom kernel.
* Get the example kernel to compile.
* Add some cuda code.
* More cuda custom op.
* More cuda custom ops.
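In outline: candle-examples/build.rs compiles every .cu file under examples/custom-ops/kernels/ to PTX with nvcc and writes cuda_kernels.rs, which embeds each PTX file through an include_str! constant; the new custom-ops example then wraps the kernel in a CustomOp1 implementation. Below is a minimal sketch of how the pieces are meant to fit together, using only names that appear in this diff; the actual kernel launch is not wired up in this commit, so the surrounding glue is an assumption, not part of the change:

    // Sketch only (not in this commit). LAYERNORM_KERNELS comes from the
    // generated cuda_kernels.rs, LayerNorm and custom_op1 from
    // examples/custom-ops/main.rs; the LayerNorm impl is omitted here.
    mod cuda_kernels;

    use candle::{Result, Tensor};

    fn apply_layer_norm(t: &Tensor) -> Result<Tensor> {
        // PTX source for the kernel, embedded at build time by build.rs.
        let _ptx: &str = cuda_kernels::LAYERNORM_KERNELS;
        // Dispatches into LayerNorm::cpu_fwd / cuda_fwd depending on the device.
        t.custom_op1(LayerNorm)
    }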
@@ -771,6 +771,50 @@ pub struct CudaStorage {
    device: CudaDevice,
}

pub trait CudaDType: Sized {
    fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
    fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
}

macro_rules! cuda_dtype {
    ($ty:ty, $dtype:ident) => {
        impl CudaDType for $ty {
            fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
                match &s.slice {
                    CudaStorageSlice::$dtype(data) => Ok(&data),
                    _ => Err(crate::Error::UnexpectedDType {
                        expected: DType::$dtype,
                        got: s.dtype(),
                        msg: "unexpected dtype",
                    }
                    .bt()),
                }
            }

            fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
                let slice = CudaStorageSlice::$dtype(slice);
                CudaStorage { slice, device }
            }
        }
    };
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
cuda_dtype!(f64, F64);

impl CudaStorage {
    pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage {
        T::wrap_cuda_slice(slice, device)
    }

    pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
        T::as_cuda_slice(self)
    }
}

fn gemm_config<T>(
    alpha: T,
    beta: T,
@@ -30,6 +30,9 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }

[build-dependencies]
anyhow = { workspace = true }

[features]
default = []
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
candle-examples/build.rs (new file)
@@ -0,0 +1,231 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::io::Write;
use std::path::PathBuf;

struct KernelDirectories {
    kernel_dir: &'static str,
    rust_target: &'static str,
}

const DIRS: [KernelDirectories; 1] = [KernelDirectories {
    kernel_dir: "examples/custom-ops/kernels/",
    rust_target: "examples/custom-ops/cuda_kernels.rs",
}];

impl KernelDirectories {
    fn maybe_build_ptx(
        &self,
        cu_file: &std::path::Path,
        ptx_file: &std::path::Path,
        compute_cap: usize,
    ) -> Result<()> {
        let should_compile = if ptx_file.exists() {
            let ptx_modified = ptx_file.metadata()?.modified()?;
            let cu_modified = cu_file.metadata()?.modified()?;
            cu_modified.duration_since(ptx_modified).is_ok()
        } else {
            true
        };
        if should_compile {
            #[cfg(feature = "cuda")]
            {
                let mut command = std::process::Command::new("nvcc");
                let out_dir = ptx_file.parent().context("no parent for ptx file")?;
                command
                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
                    .arg("--ptx")
                    .args(["--default-stream", "per-thread"])
                    .args(["--output-directory", out_dir.to_str().unwrap()])
                    .arg(format!("-I/{}", self.kernel_dir))
                    .arg(cu_file);
                let output = command
                    .spawn()
                    .context("failed spawning nvcc")?
                    .wait_with_output()?;
                if !output.status.success() {
                    anyhow::bail!(
                        "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                        String::from_utf8_lossy(&output.stdout),
                        String::from_utf8_lossy(&output.stderr)
                    )
                }
            }
            #[cfg(not(feature = "cuda"))]
            std::fs::OpenOptions::new()
                .create(true)
                .write(true)
                .open(ptx_file)?;
        }
        Ok(())
    }
    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
        println!("cargo:rerun-if-changed={}", self.kernel_dir);
        let kernel_dir = PathBuf::from(self.kernel_dir);
        let out_dir = out_dir.join(self.kernel_dir);
        if !out_dir.exists() {
            std::fs::create_dir_all(&out_dir)?;
        }
        let mut cu_files = vec![];
        let mut cuh_files = vec![];
        for file in std::fs::read_dir(kernel_dir)?.flatten() {
            let file = file.path();
            match file.extension().and_then(|v| v.to_str()) {
                Some("cu") => cu_files.push(file),
                Some("cuh") => cuh_files.push(file),
                _ => {}
            }
        }

        let mut ptx_paths = vec![];
        for cu_file in cu_files.iter() {
            let file_stem = cu_file
                .file_stem()
                .with_context(|| format!("no stem {cu_file:?}"))?;
            let file_stem = file_stem.to_string_lossy().into_owned();
            let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
            ptx_paths.push(ptx_file);
        }

        let regenerate_rs_file = true;
        if regenerate_rs_file {
            let mut file = std::fs::File::create(self.rust_target)?;
            for ptx_path in ptx_paths {
                let name = ptx_path
                    .file_stem()
                    .context("empty stem")?
                    .to_string_lossy();
                let const_definition = format!(
                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
                    name.to_uppercase().replace('.', "_"),
                    self.kernel_dir,
                );
                file.write_all(const_definition.as_bytes())?;
                file.write_all(b"\n")?;
            }
        }
        Ok(())
    }
}

fn main() -> Result<()> {
    println!("cargo:rerun-if-changed=build.rs");

    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
    let out_dir = PathBuf::from(out_dir);
    #[cfg(feature = "cuda")]
    set_cuda_include_dir()?;
    #[cfg(feature = "cuda")]
    let compute_cap = compute_cap()?;
    #[cfg(not(feature = "cuda"))]
    let compute_cap = 0;
    for d in DIRS {
        d.process(&out_dir, compute_cap)?
    }
    Ok(())
}

fn set_cuda_include_dir() -> Result<()> {
    // NOTE: copied from cudarc build.rs.
    let env_vars = [
        "CUDA_PATH",
        "CUDA_ROOT",
        "CUDA_TOOLKIT_ROOT_DIR",
        "CUDNN_LIB",
    ];
    let env_vars = env_vars
        .into_iter()
        .map(std::env::var)
        .filter_map(Result::ok)
        .map(Into::<PathBuf>::into);

    let roots = [
        "/usr",
        "/usr/local/cuda",
        "/opt/cuda",
        "/usr/lib/cuda",
        "C:/Program Files/NVIDIA GPU Computing Toolkit",
        "C:/CUDA",
    ];
    let roots = roots.into_iter().map(Into::<PathBuf>::into);
    let root = env_vars
        .chain(roots)
        .find(|path| path.join("include").join("cuda.h").is_file())
        .context("cannot find include/cuda.h")?;
    println!(
        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
        root.join("include").display()
    );
    Ok(())
}

#[allow(unused)]
fn compute_cap() -> Result<usize> {
    // Grab compute cap from nvidia-smi
    let mut compute_cap = {
        let out = std::process::Command::new("nvidia-smi")
            .arg("--query-gpu=compute_cap")
            .arg("--format=csv")
            .output()
            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
        let mut lines = out.lines();
        assert_eq!(
            lines.next().context("missing line in stdout")?,
            "compute_cap"
        );
        let cap = lines
            .next()
            .context("missing line in stdout")?
            .replace('.', "");
        cap.parse::<usize>()
            .with_context(|| format!("cannot parse as int {cap}"))?
    };

    // Grab available GPU codes from nvcc and select the highest one
    let max_nvcc_code = {
        let out = std::process::Command::new("nvcc")
            .arg("--list-gpu-code")
            .output()
            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
        let out = std::str::from_utf8(&out.stdout).unwrap();

        let out = out.lines().collect::<Vec<&str>>();
        let mut codes = Vec::with_capacity(out.len());
        for code in out {
            let code = code.split('_').collect::<Vec<&str>>();
            if !code.is_empty() && code.contains(&"sm") {
                if let Ok(num) = code[1].parse::<usize>() {
                    codes.push(num);
                }
            }
        }
        codes.sort();
        if !codes.contains(&compute_cap) {
            anyhow::bail!(
                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
            );
        }
        *codes.last().unwrap()
    };

    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
    // then choose the highest gpu code in nvcc
    if compute_cap > max_nvcc_code {
        println!(
            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
        );
        compute_cap = max_nvcc_code;
    }

    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
        compute_cap = compute_cap_str
            .parse::<usize>()
            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
    }
    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
    Ok(compute_cap)
}
candle-examples/examples/custom-ops/cuda_kernels.rs (new file)
@@ -0,0 +1 @@
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu (new file; path inferred from the generated constant)
@@ -0,0 +1,37 @@
#include "reduction_utils.cuh"

template <typename scalar_t>
__device__ void
rms_norm_kernel(scalar_t *__restrict__ out,          // [num_tokens, hidden_size]
                const scalar_t *__restrict__ input,  // [num_tokens, hidden_size]
                const scalar_t *__restrict__ weight, // [hidden_size]
                const float epsilon, const int num_tokens,
                const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}
extern "C" __global__ void rms_norm_kernel_f32(
    float *__restrict__ out,          // [num_tokens, hidden_size]
    const float *__restrict__ input,  // [num_tokens, hidden_size]
    const float *__restrict__ weight, // [hidden_size]
    const float epsilon, const int num_tokens,
    const int hidden_size) {
  rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
}
candle-examples/examples/custom-ops/kernels/reduction_utils.cuh (new file; path inferred from the #include above)
@@ -0,0 +1,46 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

template <typename T> __inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
  return val;
}

/* Calculate the sum of all elements in a block */
template <typename T> __inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
  // blockDim.x is not divided by 32
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);
  return val;
}
candle-examples/examples/custom-ops/main.rs (new file)
@@ -0,0 +1,65 @@
#![allow(dead_code)]
#![allow(unused)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use clap::Parser;

use candle::backend::BackendStorage;
use candle::cpu_backend;
use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

struct LayerNorm;

impl CustomOp1 for LayerNorm {
    fn name(&self) -> &'static str {
        "layer-norm"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        let s = s.as_slice::<f32>()?;
        let _s = match l.contiguous_offsets() {
            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
            Some((o1, o2)) => &s[o1..o2],
        };
        todo!()
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s: &candle::CudaStorage,
        l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        let device = s.device().clone();
        let s = s.as_cuda_slice::<f32>()?;
        let s = match l.contiguous_offsets() {
            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
            Some((o1, o2)) => s, // TODO: slice with o1 and o2
        };
        let s: std::result::Result<_, candle::cuda_backend::CudaError> =
            s.try_clone().map_err(|v| v.into());
        let s = s?;
        let s = candle::CudaStorage::wrap_cuda_slice(s, device);
        Ok((s, l.shape().clone()))
    }
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
    println!("{t}");
    let t = t.custom_op1(LayerNorm)?;
    println!("{t}");
    Ok(())
}
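Assuming the standard Cargo example layout (the commit adds candle-examples/examples/custom-ops/main.rs), the example should be runnable with cargo run --example custom-ops --features cuda. Note that cpu_fwd still ends in todo!() and cuda_fwd currently just clones its input, so the printed output is expected to match the input tensor rather than an actual layer norm.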