mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Refactor the hierarchy.
This commit is contained in:
16
candle-kernels/Cargo.toml
Normal file
16
candle-kernels/Cargo.toml
Normal file
@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
repository = "https://github.com/LaurentMazare/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT/Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
|
||||
[build-dependencies]
|
||||
glob = "0.3.1"
|
||||
rayon = "1.7.0"
|
4
candle-kernels/README.md
Normal file
4
candle-kernels/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
# candle-kernels
|
||||
|
||||
This crate contains CUDA kernels used from candle. Some of these implementations
|
||||
come from the [dfdx crate](https://github.com/coreylowman/dfdx).
|
223
candle-kernels/build.rs
Normal file
223
candle-kernels/build.rs
Normal file
@ -0,0 +1,223 @@
|
||||
use std::io::Write;
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
cuda::set_include_dir();
|
||||
let kernel_paths = cuda::build_ptx();
|
||||
// println!("cargo:warning=kernels {kernel_paths:?}");
|
||||
|
||||
let mut file = std::fs::File::create("src/lib.rs").unwrap();
|
||||
for kernel_path in kernel_paths {
|
||||
let name = kernel_path.file_stem().unwrap().to_str().unwrap();
|
||||
file.write_all(
|
||||
format!(
|
||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
|
||||
name.to_uppercase().replace('.', "_"),
|
||||
name
|
||||
)
|
||||
.as_bytes(),
|
||||
)
|
||||
.unwrap();
|
||||
file.write_all(&[b'\n']).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
mod cuda {
|
||||
pub fn set_include_dir() {
|
||||
use std::path::PathBuf;
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
// We can't actually set a env!() value from another crate,
|
||||
// so we have to do that here.
|
||||
|
||||
// use PathBuf;
|
||||
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
#[allow(unused)]
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
#[allow(unused)]
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
|
||||
#[cfg(feature = "ci-check")]
|
||||
let root: PathBuf = "ci".into();
|
||||
|
||||
#[cfg(not(feature = "ci-check"))]
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.unwrap();
|
||||
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
}
|
||||
|
||||
pub fn build_ptx() -> Vec<std::path::PathBuf> {
|
||||
use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
let out_dir = std::env::var("OUT_DIR").unwrap();
|
||||
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
|
||||
for path in &mut include_directories {
|
||||
println!("cargo:rerun-if-changed={}", path.display());
|
||||
let destination =
|
||||
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
|
||||
std::fs::copy(path.clone(), destination).unwrap();
|
||||
// remove the filename from the path so it's just the directory
|
||||
path.pop();
|
||||
}
|
||||
|
||||
include_directories.sort();
|
||||
include_directories.dedup();
|
||||
|
||||
#[allow(unused)]
|
||||
let include_options: Vec<String> = include_directories
|
||||
.into_iter()
|
||||
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
#[cfg(feature = "ci-check")]
|
||||
{
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=ci");
|
||||
|
||||
for mut kernel_path in kernel_paths.into_iter() {
|
||||
kernel_path.set_extension("ptx");
|
||||
|
||||
let mut ptx_path: PathBuf = out_dir.clone().into();
|
||||
ptx_path.push(kernel_path.as_path().file_name().unwrap());
|
||||
std::fs::File::create(ptx_path).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ci-check"))]
|
||||
{
|
||||
// let start = std::time::Instant::now();
|
||||
|
||||
// Grab compute code from nvidia-smi
|
||||
let mut compute_cap = {
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.expect("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(lines.next().unwrap(), "compute_cap");
|
||||
let cap = lines.next().unwrap().replace('.', "");
|
||||
cap.parse::<usize>().unwrap()
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let max_nvcc_code = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
if !codes.contains(&compute_cap) {
|
||||
panic!("nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}.");
|
||||
}
|
||||
*codes.last().unwrap()
|
||||
};
|
||||
|
||||
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
|
||||
// then choose the highest gpu code in nvcc
|
||||
if compute_cap > max_nvcc_code {
|
||||
println!("cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}.");
|
||||
compute_cap = max_nvcc_code;
|
||||
}
|
||||
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
compute_cap = compute_cap_str.parse::<usize>().unwrap();
|
||||
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
|
||||
}
|
||||
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||
|
||||
kernel_paths
|
||||
.iter()
|
||||
.for_each(|p| println!("cargo:rerun-if-changed={}", p.display()));
|
||||
|
||||
let children = kernel_paths
|
||||
.par_iter()
|
||||
.flat_map(|p| {
|
||||
let mut output = p.clone();
|
||||
output.set_extension("ptx");
|
||||
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
|
||||
|
||||
if output_filename.exists(){
|
||||
None
|
||||
}else{
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", &out_dir])
|
||||
// Flash attention only
|
||||
// .arg("--expt-relaxed-constexpr")
|
||||
.args(&include_options)
|
||||
.arg(p);
|
||||
// println!(
|
||||
// "cargo:warning={command:?}");
|
||||
Some((p, command.spawn()
|
||||
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
|
||||
}})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (kernel_path, child) in children {
|
||||
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
}
|
||||
|
||||
// println!(
|
||||
// "cargo:warning=Compiled {:?} cuda kernels in {:?}",
|
||||
// n,
|
||||
// start.elapsed()
|
||||
// );
|
||||
}
|
||||
kernel_paths
|
||||
}
|
||||
}
|
37
candle-kernels/src/affine.cu
Normal file
37
candle-kernels/src/affine.cu
Normal file
@ -0,0 +1,37 @@
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#define AFFINE_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const TYPENAME mul, \
|
||||
const TYPENAME add \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
out[i] = x * mul + add; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
TYPENAME x = inp ? inp[strided_i] : out[i]; \
|
||||
out[i] = x * mul + add; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
AFFINE_OP(__half, affine_f16)
|
||||
#endif
|
||||
|
||||
AFFINE_OP(float, affine_f32)
|
||||
AFFINE_OP(double, affine_f64)
|
||||
AFFINE_OP(uint32_t, affine_u32)
|
22
candle-kernels/src/binary.cu
Normal file
22
candle-kernels/src/binary.cu
Normal file
@ -0,0 +1,22 @@
|
||||
#include "binary_op_macros.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
BINARY_OP(__half, badd_f16, x + y)
|
||||
BINARY_OP(__half, bdiv_f16, x / y)
|
||||
BINARY_OP(__half, bmul_f16, x * y)
|
||||
BINARY_OP(__half, bsub_f16, x - y)
|
||||
#endif
|
||||
|
||||
BINARY_OP(float, badd_f32, x + y)
|
||||
BINARY_OP(double, badd_f64, x + y);
|
||||
BINARY_OP(uint32_t, badd_u32, x + y);
|
||||
BINARY_OP(float, bdiv_f32, x / y)
|
||||
BINARY_OP(double, bdiv_f64, x / y);
|
||||
BINARY_OP(uint32_t, bdiv_u32, x / y);
|
||||
BINARY_OP(float, bmul_f32, x * y)
|
||||
BINARY_OP(double, bmul_f64, x * y);
|
||||
BINARY_OP(uint32_t, bmul_u32, x * y);
|
||||
BINARY_OP(float, bsub_f32, x - y)
|
||||
BINARY_OP(double, bsub_f64, x - y);
|
||||
BINARY_OP(uint32_t, bsub_u32, x - y);
|
65
candle-kernels/src/binary_op_macros.cuh
Normal file
65
candle-kernels/src/binary_op_macros.cuh
Normal file
@ -0,0 +1,65 @@
|
||||
#include "cuda_utils.cuh"
|
||||
|
||||
#define BINARY_OP(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *dims_and_strides, \
|
||||
const TYPENAME *lhs, \
|
||||
const TYPENAME *rhs, \
|
||||
TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = dims_and_strides; \
|
||||
const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
|
||||
const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
|
||||
bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
|
||||
bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
|
||||
if (lhs_cont && rhs_cont) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = lhs ? lhs[i] : out[i]; \
|
||||
TYPENAME y = rhs ? rhs[i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} else if (lhs_cont) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned int tmp_i = i; \
|
||||
unsigned int rhs_i = 0; \
|
||||
for (int d = num_dims - 1; d >= 0; d--) { \
|
||||
unsigned int i_dim = tmp_i % dims[d]; \
|
||||
rhs_i += i_dim * rhs_strides[d]; \
|
||||
tmp_i /= dims[d]; \
|
||||
} \
|
||||
TYPENAME x = lhs ? lhs[i] : out[i]; \
|
||||
TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} else if (rhs_cont) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned int tmp_i = i; \
|
||||
unsigned int lhs_i = 0; \
|
||||
for (int d = num_dims - 1; d >= 0; d--) { \
|
||||
unsigned int i_dim = tmp_i % dims[d]; \
|
||||
lhs_i += i_dim * lhs_strides[d]; \
|
||||
tmp_i /= dims[d]; \
|
||||
} \
|
||||
TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
|
||||
TYPENAME y = rhs ? rhs[i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned int tmp_i = i; \
|
||||
unsigned int lhs_i = 0; \
|
||||
unsigned int rhs_i = 0; \
|
||||
for (int d = num_dims - 1; d >= 0; d--) { \
|
||||
unsigned int i_dim = tmp_i % dims[d]; \
|
||||
lhs_i += i_dim * lhs_strides[d]; \
|
||||
rhs_i += i_dim * rhs_strides[d]; \
|
||||
tmp_i /= dims[d]; \
|
||||
} \
|
||||
TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
|
||||
TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} \
|
||||
} \
|
48
candle-kernels/src/cast.cu
Normal file
48
candle-kernels/src/cast.cu
Normal file
@ -0,0 +1,48 @@
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const SRC_TYPENAME *inp, \
|
||||
DST_TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
out[i] = inp[i]; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
out[i] = inp[strided_i]; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
CAST_OP(__half, __half, cast_f16_f16)
|
||||
|
||||
CAST_OP(__half, uint32_t, cast_f16_u32)
|
||||
CAST_OP(__half, float, cast_f16_f32)
|
||||
CAST_OP(__half, double, cast_f16_f64)
|
||||
CAST_OP(uint32_t, __half, cast_u32_f16)
|
||||
CAST_OP(float, __half, cast_f32_f16)
|
||||
CAST_OP(double, __half, cast_f64_f16)
|
||||
#endif
|
||||
|
||||
CAST_OP(uint32_t, uint32_t, cast_u32_u32)
|
||||
CAST_OP(uint32_t, float, cast_u32_f32)
|
||||
CAST_OP(uint32_t, double, cast_u32_f64)
|
||||
|
||||
CAST_OP(float, uint32_t, cast_f32_u32)
|
||||
CAST_OP(float, float, cast_f32_f32)
|
||||
CAST_OP(float, double, cast_f32_f64)
|
||||
|
||||
CAST_OP(double, uint32_t, cast_f64_u32)
|
||||
CAST_OP(double, float, cast_f64_f32)
|
||||
CAST_OP(double, double, cast_f64_f64)
|
171
candle-kernels/src/compatibility.cuh
Normal file
171
candle-kernels/src/compatibility.cuh
Normal file
@ -0,0 +1,171 @@
|
||||
#include "cuda_fp16.h"
|
||||
|
||||
// Table showing which features are supported on which compute capability
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
|
||||
|
||||
// FIXME: the minimum compute capabilities are just guesses since the table is not specific enough
|
||||
|
||||
// #if __CUDA_ARCH__ < 600
|
||||
// __device__ __forceinline__ __half __hmax(__half a, __half b) {
|
||||
// return __float2half(fmaxf(__half2float(a), __half2float(b)));
|
||||
// }
|
||||
// __device__ __forceinline__ __half __hmin(__half a, __half b) {
|
||||
// return __float2half(fminf(__half2float(a), __half2float(b)));
|
||||
// }
|
||||
// #endif
|
||||
|
||||
#if __CUDA_ARCH__ < 800
|
||||
__device__ __forceinline__ __half __hmax_nan(__half a, __half b) {
|
||||
// return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b));
|
||||
}
|
||||
__device__ __forceinline__ __half __hmin_nan(__half a, __half b) {
|
||||
// return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ < 600
|
||||
// Copied from https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
|
||||
__device__ double atomicAdd(double* address, double val) {
|
||||
unsigned long long int* address_as_ull = (unsigned long long int*)address;
|
||||
unsigned long long int old = *address_as_ull, assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed,
|
||||
__double_as_longlong(val +
|
||||
__longlong_as_double(assumed)));
|
||||
|
||||
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
||||
} while (assumed != old);
|
||||
|
||||
return __longlong_as_double(old);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#if __CUDA_ARCH__ < 700
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomicadd
|
||||
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
|
||||
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
|
||||
__device__ __half atomicAdd(__half *address, __half val) {
|
||||
// unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
// unsigned int old = *address_as_ui;
|
||||
// unsigned int assumed;
|
||||
// bool unaligned = (size_t) address & 2;
|
||||
// do {
|
||||
// assumed = old;
|
||||
// unsigned int hsum;
|
||||
// hsum = unaligned ? (old >> 16) : (old & 0xffff);
|
||||
// hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
|
||||
// old = atomicCAS(address_as_ui, assumed,
|
||||
// unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
|
||||
// );
|
||||
|
||||
// } while (assumed != old);
|
||||
// return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
__device__ __forceinline__ __half atomicMaxf(__half* address, __half val) {
|
||||
#if __CUDA_ARCH__ < 700
|
||||
// On older GPUs we do not have access to atomicCAS for shorts, so we have to do some trickery.
|
||||
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
|
||||
unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
bool unaligned = (size_t) address & 2;
|
||||
do {
|
||||
assumed = old;
|
||||
unsigned int hmax;
|
||||
hmax = unaligned ? (old >> 16) : (old & 0xffff);
|
||||
hmax = __half_as_ushort(__hmax_nan(val, __ushort_as_half(hmax)));
|
||||
old = atomicCAS(address_as_ui, assumed,
|
||||
unaligned ? (old & 0xffff) | (hmax << 16) : (old & 0xffff0000) | hmax
|
||||
);
|
||||
|
||||
} while (assumed != old);
|
||||
return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
|
||||
#else
|
||||
// Based on https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
|
||||
unsigned short int* casted_address = (unsigned short int*)address;
|
||||
unsigned short int old = *casted_address;
|
||||
unsigned short int assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(casted_address, assumed, __half_as_ushort(__hmax_nan(val, __ushort_as_half(assumed))));
|
||||
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
||||
} while (assumed != old);
|
||||
return __ushort_as_half(old);
|
||||
#endif
|
||||
}
|
||||
|
||||
// atomicMax is not implemented for floats,
|
||||
// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
|
||||
__device__ __forceinline__ float atomicMaxf(float * addr, float value) {
|
||||
if (signbit(value)) {
|
||||
return __uint_as_float(atomicMin((unsigned int *)addr, __float_as_uint(value)));
|
||||
} else {
|
||||
return __int_as_float(atomicMax((int *)addr, __float_as_int(value)));
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ double atomicMaxf(double * addr, double value) {
|
||||
if (signbit(value)) {
|
||||
return __longlong_as_double(atomicMin((unsigned long long int *)addr, __double_as_longlong(value)));
|
||||
} else {
|
||||
return __longlong_as_double(atomicMax((long long int *)addr, __double_as_longlong(value)));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__device__ __forceinline__ __half atomicMinf(__half* address, __half val) {
|
||||
#if __CUDA_ARCH__ < 700
|
||||
// On older GPUs we do not have access to atomicCAS for shorts, so we have to do some trickery.
|
||||
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
|
||||
unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
bool unaligned = (size_t) address & 2;
|
||||
do {
|
||||
assumed = old;
|
||||
unsigned int hmin;
|
||||
hmin = unaligned ? (old >> 16) : (old & 0xffff);
|
||||
hmin = __half_as_ushort(__hmin_nan(val, __ushort_as_half(hmin)));
|
||||
old = atomicCAS(address_as_ui, assumed,
|
||||
unaligned ? (old & 0xffff) | (hmin << 16) : (old & 0xffff0000) | hmin
|
||||
);
|
||||
|
||||
} while (assumed != old);
|
||||
return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
|
||||
#else
|
||||
// Based on https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
|
||||
unsigned short int* casted_address = (unsigned short int*)address;
|
||||
unsigned short int old = *casted_address;
|
||||
unsigned short int assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(casted_address, assumed, __half_as_ushort(__hmin_nan(val, __ushort_as_half(assumed))));
|
||||
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
||||
} while (assumed != old);
|
||||
return __ushort_as_half(old);
|
||||
#endif
|
||||
}
|
||||
|
||||
// atomicMin is not implemented for floats,
|
||||
// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
|
||||
__device__ __forceinline__ float atomicMinf(float * addr, float value) {
|
||||
if (signbit(value)) {
|
||||
return __uint_as_float(atomicMax((unsigned int *)addr, __float_as_uint(value)));
|
||||
} else {
|
||||
return __int_as_float(atomicMin((int *)addr, __float_as_int(value)));
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ double atomicMinf(double * addr, double value) {
|
||||
if (signbit(value)) {
|
||||
return __longlong_as_double(atomicMax((unsigned long long int *)addr, __double_as_longlong(value)));
|
||||
} else {
|
||||
return __longlong_as_double(atomicMin((long long int *)addr, __double_as_longlong(value)));
|
||||
}
|
||||
}
|
158
candle-kernels/src/cuda_utils.cuh
Normal file
158
candle-kernels/src/cuda_utils.cuh
Normal file
@ -0,0 +1,158 @@
|
||||
#include "cuda_fp16.h"
|
||||
#include "compatibility.cuh"
|
||||
|
||||
// TODO: This is often used to check that the data is contiguous so that
|
||||
// kernels can be easily mapped. However this only returns true for row
|
||||
// major, if all the inputs are column major, we could apply the fast path
|
||||
// too (but we wouldn't if some of them are row major and some column major).
|
||||
__device__ bool is_contiguous(
|
||||
const size_t num_dims,
|
||||
const size_t *dims,
|
||||
const size_t *strides
|
||||
) {
|
||||
size_t acc = 1;
|
||||
for (unsigned int d = 0; d < num_dims; d++) {
|
||||
unsigned int dim_idx = num_dims - 1 - d;
|
||||
if (acc != strides[dim_idx]) {
|
||||
return false;
|
||||
}
|
||||
acc *= dims[dim_idx];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
__device__ unsigned int get_strided_index(
|
||||
unsigned int idx,
|
||||
const size_t num_dims,
|
||||
const size_t *dims,
|
||||
const size_t *strides
|
||||
) {
|
||||
unsigned int strided_i = 0;
|
||||
for (unsigned int d = 0; d < num_dims; d++) {
|
||||
unsigned int dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
__device__ unsigned int restrided(
|
||||
const unsigned int strided_i,
|
||||
const size_t num_dims,
|
||||
const size_t *dims,
|
||||
const size_t *strides,
|
||||
const size_t *new_strides
|
||||
) {
|
||||
unsigned int idx = 0;
|
||||
for (int d = 0; d < num_dims; d++) {
|
||||
idx += (strides[d] == 0 ? 0 : (strided_i / strides[d]) % dims[d]) * new_strides[d];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
// Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
|
||||
// Input must be less than or equal to 2 ^ 16
|
||||
// used in reductions
|
||||
__device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) {
|
||||
v--;
|
||||
v |= v >> 1;
|
||||
v |= v >> 2;
|
||||
v |= v >> 4;
|
||||
v |= v >> 8;
|
||||
v++;
|
||||
return v;
|
||||
}
|
||||
|
||||
// Efficiently computes the sum of each chunk in "data" of size chunk_len, and
|
||||
// stores the sums in out[i / chunk_len]
|
||||
template<typename T>
|
||||
__device__ void chunk_sum(
|
||||
const size_t chunk_len,
|
||||
const T data,
|
||||
T* out
|
||||
) {
|
||||
__shared__ T buf[1024];
|
||||
|
||||
// assumes that threads where i >= numel have already exited
|
||||
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int block_i = threadIdx.x;
|
||||
|
||||
// Fall back to atomicAdd if chunk_len is small to reduce overhead
|
||||
if (chunk_len <= 2) {
|
||||
atomicAdd(out + i / chunk_len, data);
|
||||
return;
|
||||
}
|
||||
buf[block_i] = data;
|
||||
|
||||
unsigned int chunk_i = i % chunk_len;
|
||||
unsigned int chunk_start = max((int)(block_i - chunk_i), 0);
|
||||
unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x);
|
||||
|
||||
chunk_i = block_i - chunk_start;
|
||||
|
||||
size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x);
|
||||
size_t incr = next_power_of_two(max_chunk_len) >> 1;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Uses sequential addressing as discussed in
|
||||
// https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
|
||||
for (; incr > 0; incr >>= 1) {
|
||||
unsigned int block_i_2 = block_i + incr;
|
||||
|
||||
if (block_i_2 < chunk_end && chunk_i < incr) {
|
||||
// This is sound because __syncthreads and the conditions above
|
||||
// ensure that no data races occur
|
||||
buf[block_i] += buf[block_i_2];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (block_i == chunk_start) {
|
||||
atomicAdd(out + i / chunk_len, buf[block_i]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool isnang(float a) { return isnan(a); }
|
||||
__device__ __forceinline__ bool isnang(double a) { return isnan(a); }
|
||||
__device__ __forceinline__ float recipg(float a) { return 1.0 / a; }
|
||||
__device__ __forceinline__ double recipg(double a) { return 1.0 / a; }
|
||||
__device__ __forceinline__ float cosg(float a) { return cosf(a); }
|
||||
__device__ __forceinline__ double cosg(double a) { return cos(a); }
|
||||
__device__ __forceinline__ float sing(float a) { return sinf(a); }
|
||||
__device__ __forceinline__ double sing(double a) { return sin(a); }
|
||||
__device__ __forceinline__ float sqrtg(float a) { return sqrtf(a); }
|
||||
__device__ __forceinline__ double sqrtg(double a) { return sqrt(a); }
|
||||
__device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
|
||||
__device__ __forceinline__ double powg(double a, double b) { return pow(a, b); }
|
||||
__device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
|
||||
__device__ __forceinline__ double tanhg(double a) { return tanh(a); }
|
||||
__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }
|
||||
__device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); }
|
||||
__device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); }
|
||||
__device__ __forceinline__ double ming(double a, double b) { return fmin(a, b); }
|
||||
__device__ __forceinline__ float logg(float a) { return logf(a); }
|
||||
__device__ __forceinline__ double logg(double a) { return log(a); }
|
||||
__device__ __forceinline__ float expg(float a) { return expf(a); }
|
||||
__device__ __forceinline__ double expg(double a) { return exp(a); }
|
||||
__device__ __forceinline__ float absg(float a) { return fabsf(a); }
|
||||
__device__ __forceinline__ double absg(double a) { return fabs(a); }
|
||||
__device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); }
|
||||
__device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); }
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
|
||||
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
|
||||
__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
|
||||
__device__ __forceinline__ __half cosg(__half a) { return hcos(a); }
|
||||
__device__ __forceinline__ __half sing(__half a) { return hsin(a); }
|
||||
__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
|
||||
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
|
||||
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
|
||||
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
|
||||
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
|
||||
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
|
||||
__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
|
||||
__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
|
||||
#endif
|
38
candle-kernels/src/embeddings.cu
Normal file
38
candle-kernels/src/embeddings.cu
Normal file
@ -0,0 +1,38 @@
|
||||
// WARNING: THIS IS ONLY VALID ASSUMING THAT inp IS CONTIGUOUS!
|
||||
// TODO: proper error reporting when ids are larger than v_size.
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#define EMB_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const uint32_t *ids, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const size_t h_size, \
|
||||
const size_t v_size \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
memcpy(&out[i * h_size], &inp[ids[i] * h_size], h_size * sizeof(TYPENAME)); \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
memcpy(&out[i * h_size], &inp[ids[strided_i] * h_size], h_size * sizeof(TYPENAME)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
EMB_OP(__half, emb_f16)
|
||||
#endif
|
||||
|
||||
EMB_OP(float, emb_f32)
|
||||
EMB_OP(double, emb_f64)
|
||||
EMB_OP(uint32_t, emb_u32)
|
11
candle-kernels/src/fill.cu
Normal file
11
candle-kernels/src/fill.cu
Normal file
@ -0,0 +1,11 @@
|
||||
#include "cuda_fp16.h"
|
||||
|
||||
template<typename T>
|
||||
__device__ void fill_with(T *buf, T value, const size_t numel) {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
buf[i] = value;
|
||||
}
|
||||
}
|
||||
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
|
8
candle-kernels/src/lib.rs
Normal file
8
candle-kernels/src/lib.rs
Normal file
@ -0,0 +1,8 @@
|
||||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
52
candle-kernels/src/reduce.cu
Normal file
52
candle-kernels/src/reduce.cu
Normal file
@ -0,0 +1,52 @@
|
||||
// TODO: Use a proper distributed reduction rather than atomicAdd.
|
||||
// https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#define SUM_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t num_sum_dims, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
const size_t *sum_dims_l = info + 2*num_dims; \
|
||||
const size_t *sum_dims_s = info + 2*num_dims + num_sum_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
size_t dst_index = i; \
|
||||
for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
|
||||
size_t stride = sum_dims_s[nd]; \
|
||||
size_t pre = dst_index / stride; \
|
||||
size_t post = dst_index % stride; \
|
||||
dst_index = (pre / sum_dims_l[nd]) * stride + post; \
|
||||
} \
|
||||
atomicAdd(out + dst_index, inp[i]); \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
size_t dst_index = i; \
|
||||
for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
|
||||
size_t stride = sum_dims_s[nd]; \
|
||||
size_t pre = dst_index / stride; \
|
||||
size_t post = dst_index % stride; \
|
||||
dst_index = (pre / sum_dims_l[nd]) * stride + post; \
|
||||
} \
|
||||
atomicAdd(out + dst_index, inp[strided_i]); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
SUM_OP(__half, sum_f16)
|
||||
#endif
|
||||
|
||||
SUM_OP(float, sum_f32)
|
||||
SUM_OP(double, sum_f64)
|
||||
SUM_OP(uint32_t, sum_u32)
|
41
candle-kernels/src/ternary.cu
Normal file
41
candle-kernels/src/ternary.cu
Normal file
@ -0,0 +1,41 @@
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#define WHERE_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const uint32_t *ids, \
|
||||
const TYPENAME *t, \
|
||||
const TYPENAME *f, \
|
||||
TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
const size_t *strides_t = info + 2*num_dims; \
|
||||
const size_t *strides_f = info + 2*num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides) \
|
||||
&& is_contiguous(num_dims, dims, strides_f) \
|
||||
&& is_contiguous(num_dims, dims, strides_t)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
out[i] = ids[i] ? t[i] : f[i]; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
unsigned strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||
unsigned strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
WHERE_OP(__half, where_f16)
|
||||
#endif
|
||||
|
||||
WHERE_OP(float, where_f32)
|
||||
WHERE_OP(double, where_f64)
|
||||
WHERE_OP(uint32_t, where_u32)
|
69
candle-kernels/src/unary.cu
Normal file
69
candle-kernels/src/unary.cu
Normal file
@ -0,0 +1,69 @@
|
||||
#include "cuda_utils.cuh"
|
||||
|
||||
#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
TYPENAME x = inp ? inp[strided_i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
__device__ T gelu_fwd(T x) {
|
||||
T x_sq = x * x;
|
||||
T x_cube = x_sq * x;
|
||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
|
||||
}
|
||||
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
UNARY_OP(__half, ucopy_f16, x)
|
||||
UNARY_OP(__half, uneg_f16, -x)
|
||||
UNARY_OP(__half, uexp_f16, expg(x))
|
||||
UNARY_OP(__half, ulog_f16, logg(x))
|
||||
UNARY_OP(__half, usin_f16, sing(x))
|
||||
UNARY_OP(__half, ucos_f16, cosg(x))
|
||||
UNARY_OP(__half, uabs_f16, absg(x))
|
||||
UNARY_OP(__half, usqr_f16, x*x)
|
||||
UNARY_OP(__half, usqrt_f16, sqrtg(x))
|
||||
UNARY_OP(__half, gelu_f16, gelu_fwd(x))
|
||||
#endif
|
||||
|
||||
UNARY_OP(float, ucopy_f32, x)
|
||||
UNARY_OP(double, ucopy_f64, x)
|
||||
UNARY_OP(float, uneg_f32, -x)
|
||||
UNARY_OP(double, uneg_f64, -x)
|
||||
UNARY_OP(float, uexp_f32, expg(x))
|
||||
UNARY_OP(double, uexp_f64, expg(x))
|
||||
UNARY_OP(float, ulog_f32, logg(x))
|
||||
UNARY_OP(double, ulog_f64, logg(x))
|
||||
UNARY_OP(float, usin_f32, sing(x))
|
||||
UNARY_OP(double, usin_f64, sing(x))
|
||||
UNARY_OP(float, ucos_f32, cosg(x))
|
||||
UNARY_OP(double, ucos_f64, cosg(x))
|
||||
UNARY_OP(float, uabsg_f32, absg(x))
|
||||
UNARY_OP(double, uabsg_f64, absg(x))
|
||||
UNARY_OP(float, usqr_f32, x*x)
|
||||
UNARY_OP(double, usqr_f64, x*x)
|
||||
UNARY_OP(float, usqrt_f32, sqrtg(x))
|
||||
UNARY_OP(double, usqrt_f64, sqrtg(x))
|
||||
UNARY_OP(float, gelu_f32, gelu_fwd(x))
|
||||
UNARY_OP(double, gelu_f64, gelu_fwd(x))
|
Reference in New Issue
Block a user