diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f9fefe17..d2cc3e41 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -771,6 +771,50 @@ pub struct CudaStorage {
     device: CudaDevice,
 }
 
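+/// Maps a Rust scalar type to its dtype-tagged `CudaStorageSlice` variant, so that generic
+/// code can unwrap a `CudaStorage` into a typed `CudaSlice<T>` and wrap a slice back up.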
+pub trait CudaDType: Sized {
+    fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
+    fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
+}
+
+macro_rules! cuda_dtype {
+    ($ty:ty, $dtype:ident) => {
+        impl CudaDType for $ty {
+            fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
+                match &s.slice {
+                    CudaStorageSlice::$dtype(data) => Ok(&data),
+                    _ => Err(crate::Error::UnexpectedDType {
+                        expected: DType::$dtype,
+                        got: s.dtype(),
+                        msg: "unexpected dtype",
+                    }
+                    .bt()),
+                }
+            }
+
+            fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
+                let slice = CudaStorageSlice::$dtype(slice);
+                CudaStorage { slice, device }
+            }
+        }
+    };
+}
+cuda_dtype!(u8, U8);
+cuda_dtype!(u32, U32);
+cuda_dtype!(f16, F16);
+cuda_dtype!(bf16, BF16);
+cuda_dtype!(f32, F32);
+cuda_dtype!(f64, F64);
+
+impl CudaStorage {
+    pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage {
+        T::wrap_cuda_slice(slice, device)
+    }
+
+    pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
+        T::as_cuda_slice(self)
+    }
+}
+
 fn gemm_config<T>(
     alpha: T,
     beta: T,
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 24435e81..f940a937 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -30,6 +30,9 @@ tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
 wav = { workspace = true }
 
+[build-dependencies]
+anyhow = { workspace = true }
+
 [features]
 default = []
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
diff --git a/candle-examples/build.rs b/candle-examples/build.rs
new file mode 100644
index 00000000..7f69fa77
--- /dev/null
+++ b/candle-examples/build.rs
@@ -0,0 +1,231 @@
+#![allow(unused)]
+use anyhow::{Context, Result};
+use std::io::Write;
+use std::path::PathBuf;
+
+struct KernelDirectories {
+    kernel_dir: &'static str,
+    rust_target: &'static str,
+}
+
+const DIRS: [KernelDirectories; 1] = [KernelDirectories {
+    kernel_dir: "examples/custom-ops/kernels/",
+    rust_target: "examples/custom-ops/cuda_kernels.rs",
+}];
+
+impl KernelDirectories {
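+    /// Compile `cu_file` to PTX, skipping the nvcc invocation when the existing
+    /// PTX file is already newer than the CUDA source.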
+    fn maybe_build_ptx(
+        &self,
+        cu_file: &std::path::Path,
+        ptx_file: &std::path::Path,
+        compute_cap: usize,
+    ) -> Result<()> {
+        let should_compile = if ptx_file.exists() {
+            let ptx_modified = ptx_file.metadata()?.modified()?;
+            let cu_modified = cu_file.metadata()?.modified()?;
+            cu_modified.duration_since(ptx_modified).is_ok()
+        } else {
+            true
+        };
+        if should_compile {
+            #[cfg(feature = "cuda")]
+            {
+                let mut command = std::process::Command::new("nvcc");
+                let out_dir = ptx_file.parent().context("no parent for ptx file")?;
+                command
+                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
+                    .arg("--ptx")
+                    .args(["--default-stream", "per-thread"])
+                    .args(["--output-directory", out_dir.to_str().unwrap()])
+                    .arg(format!("-I/{}", self.kernel_dir))
+                    .arg(cu_file);
+                let output = command
+                    .spawn()
+                    .context("failed spawning nvcc")?
+                    .wait_with_output()?;
+                if !output.status.success() {
+                    anyhow::bail!(
+                        "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+                        String::from_utf8_lossy(&output.stdout),
+                        String::from_utf8_lossy(&output.stderr)
+                    )
+                }
+            }
+            #[cfg(not(feature = "cuda"))]
+            std::fs::OpenOptions::new()
+                .create(true)
+                .write(true)
+                .open(ptx_file)?;
+        }
+        Ok(())
+    }
+    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
+        println!("cargo:rerun-if-changed={}", self.kernel_dir);
+        let kernel_dir = PathBuf::from(self.kernel_dir);
+        let out_dir = out_dir.join(self.kernel_dir);
+        if !out_dir.exists() {
+            std::fs::create_dir_all(&out_dir)?;
+        }
+        let mut cu_files = vec![];
+        let mut cuh_files = vec![];
+        for file in std::fs::read_dir(kernel_dir)?.flatten() {
+            let file = file.path();
+            match file.extension().and_then(|v| v.to_str()) {
+                Some("cu") => cu_files.push(file),
+                Some("cuh") => cuh_files.push(file),
+                _ => {}
+            }
+        }
+
+        let mut ptx_paths = vec![];
+        for cu_file in cu_files.iter() {
+            let file_stem = cu_file
+                .file_stem()
+                .with_context(|| format!("no stem {cu_file:?}"))?;
+            let file_stem = file_stem.to_string_lossy().into_owned();
+            let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
+            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
+            ptx_paths.push(ptx_file);
+        }
+
+        let regenerate_rs_file = true;
+        if regenerate_rs_file {
+            let mut file = std::fs::File::create(self.rust_target)?;
+            for ptx_path in ptx_paths {
+                let name = ptx_path
+                    .file_stem()
+                    .context("empty stem")?
+                    .to_string_lossy();
+                let const_definition = format!(
+                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
+                    name.to_uppercase().replace('.', "_"),
+                    self.kernel_dir,
+                );
+                file.write_all(const_definition.as_bytes())?;
+                file.write_all(b"\n")?;
+            }
+        }
+        Ok(())
+    }
+}
+
+fn main() -> Result<()> {
+    println!("cargo:rerun-if-changed=build.rs");
+
+    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
+    let out_dir = PathBuf::from(out_dir);
+    #[cfg(feature = "cuda")]
+    set_cuda_include_dir()?;
+    #[cfg(feature = "cuda")]
+    let compute_cap = compute_cap()?;
+    #[cfg(not(feature = "cuda"))]
+    let compute_cap = 0;
+    for d in DIRS {
+        d.process(&out_dir, compute_cap)?
+    }
+    Ok(())
+}
+
+fn set_cuda_include_dir() -> Result<()> {
+    // NOTE: copied from cudarc build.rs.
+    let env_vars = [
+        "CUDA_PATH",
+        "CUDA_ROOT",
+        "CUDA_TOOLKIT_ROOT_DIR",
+        "CUDNN_LIB",
+    ];
+    let env_vars = env_vars
+        .into_iter()
+        .map(std::env::var)
+        .filter_map(Result::ok)
+        .map(Into::<PathBuf>::into);
+
+    let roots = [
+        "/usr",
+        "/usr/local/cuda",
+        "/opt/cuda",
+        "/usr/lib/cuda",
+        "C:/Program Files/NVIDIA GPU Computing Toolkit",
+        "C:/CUDA",
+    ];
+    let roots = roots.into_iter().map(Into::<PathBuf>::into);
+    let root = env_vars
+        .chain(roots)
+        .find(|path| path.join("include").join("cuda.h").is_file())
+        .context("cannot find include/cuda.h")?;
+    println!(
+        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
+        root.join("include").display()
+    );
+    Ok(())
+}
+
+#[allow(unused)]
+fn compute_cap() -> Result<usize> {
+    // Grab the compute cap from nvidia-smi
+    let mut compute_cap = {
+        let out = std::process::Command::new("nvidia-smi")
+            .arg("--query-gpu=compute_cap")
+            .arg("--format=csv")
+            .output()
+            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
+        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
+        let mut lines = out.lines();
+        assert_eq!(
+            lines.next().context("missing line in stdout")?,
+            "compute_cap"
+        );
+        let cap = lines
+            .next()
+            .context("missing line in stdout")?
+            .replace('.', "");
+        cap.parse::<usize>()
+            .with_context(|| format!("cannot parse as int {cap}"))?
+    };
+
+    // Grab the available GPU codes from nvcc and select the highest one
+    let max_nvcc_code = {
+        let out = std::process::Command::new("nvcc")
+            .arg("--list-gpu-code")
+            .output()
+            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
+        let out = std::str::from_utf8(&out.stdout).unwrap();
+
+        let out = out.lines().collect::<Vec<&str>>();
+        let mut codes = Vec::with_capacity(out.len());
+        for code in out {
+            let code = code.split('_').collect::<Vec<&str>>();
+            if !code.is_empty() && code.contains(&"sm") {
+                if let Ok(num) = code[1].parse::<usize>() {
+                    codes.push(num);
+                }
+            }
+        }
+        codes.sort();
+        if !codes.contains(&compute_cap) {
+            anyhow::bail!(
+                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
+            );
+        }
+        *codes.last().unwrap()
+    };
+
+    // If the nvidia-smi compute cap is higher than the highest gpu code supported
+    // by nvcc, fall back to the highest nvcc gpu code.
+    if compute_cap > max_nvcc_code {
+        println!(
+            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+        );
+        compute_cap = max_nvcc_code;
+    }
+
+    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+        compute_cap = compute_cap_str
+            .parse::<usize>()
+            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
+        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+    }
+    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+    Ok(compute_cap)
+}
diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs
new file mode 100644
index 00000000..07d18342
--- /dev/null
+++ b/candle-examples/examples/custom-ops/cuda_kernels.rs
@@ -0,0 +1 @@
+pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
diff --git a/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
new file mode 100644
index 00000000..07ab8639
--- /dev/null
+++ b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
@@ -0,0 +1,37 @@
+#include "reduction_utils.cuh"
+
+template <typename scalar_t>
+__device__ void
+rms_norm_kernel(scalar_t *__restrict__ out,          // [num_tokens, hidden_size]
+                const scalar_t *__restrict__ input,  // [num_tokens, hidden_size]
+                const scalar_t *__restrict__ weight, // [hidden_size]
+                const float epsilon, const int num_tokens,
+                const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float)input[blockIdx.x * hidden_size + idx];
+    out[blockIdx.x * hidden_size + idx] =
+        ((scalar_t)(x * s_variance)) * weight[idx];
+  }
+}
+
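+// The extern "C" wrapper instantiates the template for f32 and keeps the symbol
+// unmangled, so the Rust side can look the kernel up by name in the generated PTX.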
+extern "C" __global__ void rms_norm_kernel_f32(
+    float *__restrict__ out,          // [num_tokens, hidden_size]
+    const float *__restrict__ input,  // [num_tokens, hidden_size]
+    const float *__restrict__ weight, // [hidden_size]
+    const float epsilon, const int num_tokens,
+    const int hidden_size) {
+  rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
+}
diff --git a/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh b/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh
new file mode 100644
index 00000000..d5765f4f
--- /dev/null
+++ b/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh
@@ -0,0 +1,46 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
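+// Warp-level reduction: each shuffle-XOR step halves the stride, so after five
+// steps every lane holds the sum over all 32 lanes of the warp.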
+template <typename T> __inline__ __device__ T warpReduceSum(T val) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1)
+    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
+  return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template <typename T> __inline__ __device__ T blockReduceSum(T val) {
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  val = warpReduceSum<T>(val);
+
+  if (lane == 0)
+    shared[wid] = val;
+
+  __syncthreads();
+
+  // Use blockDim.x / 32.f rather than a shift so that block sizes that are
+  // not a multiple of 32 are handled correctly.
+  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val);
+  return val;
+}
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
new file mode 100644
index 00000000..adc7abd7
--- /dev/null
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -0,0 +1,65 @@
+#![allow(dead_code)]
+#![allow(unused)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use clap::Parser;
+
+use candle::backend::BackendStorage;
+use candle::cpu_backend;
+use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+}
+
+struct LayerNorm;
+
+impl CustomOp1 for LayerNorm {
+    fn name(&self) -> &'static str {
+        "layer-norm"
+    }
+
+    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
+        let s = s.as_slice::<f32>()?;
+        let _s = match l.contiguous_offsets() {
+            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+            Some((o1, o2)) => &s[o1..o2],
+        };
+        todo!()
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s: &candle::CudaStorage,
+        l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        let device = s.device().clone();
+        let s = s.as_cuda_slice::<f32>()?;
+        let s = match l.contiguous_offsets() {
+            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+            Some((o1, o2)) => s, // TODO: slice with o1 and o2
+        };
+        let s: std::result::Result<_, candle::cuda_backend::CudaError> =
+            s.try_clone().map_err(|v| v.into());
+        let s = s?;
+        let s = candle::CudaStorage::wrap_cuda_slice(s, device);
+        Ok((s, l.shape().clone()))
+    }
+}
+
+fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    let device = candle_examples::device(args.cpu)?;
+    let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
+    println!("{t}");
+    let t = t.custom_op1(LayerNorm)?;
+    println!("{t}");
+    Ok(())
+}