Merge pull request #10 from LaurentMazare/f16

Add support for f16 and bf16
Laurent Mazare
2023-06-27 05:59:59 +01:00
committed by GitHub
11 changed files with 770 additions and 312 deletions
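Throughout this diff, the `f16`/`bf16` values come from the `half` crate (added to the dependencies below). As a quick refresher, here is a minimal sketch of the conversions the new code leans on; this snippet is illustrative and not part of the commit:

```rust
// Illustrative only: the `half` crate conversions used throughout this PR.
use half::{bf16, f16};

fn main() {
    // Narrowing casts round to the nearest representable value.
    let x = f16::from_f32(1.2345_f32);
    let y = bf16::from_f64(1.2345_f64);
    // Widening back is exact.
    println!("{} {}", x.to_f32(), y.to_f64());
    // The trade-off: bf16 keeps f32's exponent range with ~8 bits of
    // precision, while f16 has more precision but overflows past ~65504.
    assert!(f16::from_f32(1e6).is_infinite());
    assert!(bf16::from_f32(1e6).is_finite());
}
```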


@@ -20,6 +20,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: check
args: --no-default-features
test:
name: Test Suite
@@ -38,6 +39,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: test
args: --no-default-features
fmt:
name: Rustfmt
@@ -69,4 +71,4 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: clippy
args: --tests --examples -- -D warnings
args: --no-default-features --tests --examples -- -D warnings


@@ -18,11 +18,13 @@ members = [
[dependencies]
safetensors = "0.3.1"
thiserror = "1"
cudarc = { version = "0.9.9", optional = true }
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
candle-kernels = { path = "kernels", optional = true }
gemm = "0.15.4"
zip = { version = "0.6.6", default-features=false }
byteorder = "1.4.3"
half = { version = "2.3.1", features = ["num-traits"] }
num-traits = "0.2.15"
[dev-dependencies]
anyhow = "1"
@@ -31,5 +33,5 @@ rand = "0.8.5"
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
[features]
default = []
default = ["cuda"]
cuda = ["dep:cudarc", "dep:candle-kernels"]
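Note the interplay with the CI changes above: `default = ["cuda"]` makes local builds GPU-first, while the workflows now pass `--no-default-features` so check, test, and clippy still run on GPU-less runners. Downstream code would then gate CUDA-only paths on the feature, roughly like this (hypothetical sketch, not code from this commit):

```rust
// Hypothetical feature-gating sketch; `backend_name` is not part of the PR.
#[cfg(feature = "cuda")]
fn backend_name() -> &'static str {
    "cuda" // compiled only when the (now default) cuda feature is enabled
}

#[cfg(not(feature = "cuda"))]
fn backend_name() -> &'static str {
    "cpu" // what the --no-default-features CI builds fall back to
}

fn main() {
    println!("backend: {}", backend_name());
}
```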


@@ -28,6 +28,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \
#if __CUDA_ARCH__ >= 530
AFFINE_OP(__half, affine_f16)
#endif
AFFINE_OP(float, affine_f32)
AFFINE_OP(double, affine_f64)
AFFINE_OP(uint32_t, affine_u32)


@@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \
#if __CUDA_ARCH__ >= 530
EMB_OP(__half, emb_f16)
#endif
EMB_OP(float, emb_f32)
EMB_OP(double, emb_f64)
EMB_OP(uint32_t, emb_u32)


@@ -43,6 +43,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \
#if __CUDA_ARCH__ >= 530
SUM_OP(__half, sum_f16)
#endif
SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)


@@ -32,6 +32,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \
#if __CUDA_ARCH__ >= 530
WHERE_OP(__half, where_f16)
#endif
WHERE_OP(float, where_f32)
WHERE_OP(double, where_f64)
WHERE_OP(uint32_t, where_u32)


@@ -1,6 +1,7 @@
use crate::op::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
use gemm::{gemm, Parallelism};
use half::{bf16, f16};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@@ -9,6 +10,8 @@ use gemm::{gemm, Parallelism};
#[derive(Debug, Clone)]
pub enum CpuStorage {
U32(Vec<u32>),
BF16(Vec<bf16>),
F16(Vec<f16>),
F32(Vec<f32>),
F64(Vec<f64>),
}
@@ -43,10 +46,37 @@ fn wcond<T: Copy>(
}
}
macro_rules! map1 {
($v: expr, $fn: ident, $( $args:expr ),*) => {{
let v = match $v {
CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
};
Ok(v)
}};
}
fn sum_impl1<T: Copy + num_traits::NumAssign>(
src: &[T],
dst_shape: &Shape,
src_dims: &[usize],
stride: &[usize],
to_dst_index: impl Fn(usize) -> usize,
) -> Result<Vec<T>> {
let mut dst = vec![T::zero(); dst_shape.elem_count()];
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
dst[to_dst_index(unstr_index)] += src[src_index];
}
Ok(dst)
}
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
shape: &Shape,
stride: &[usize],
vs: &[T],
mut f: F,
) -> Vec<U> {
if shape.is_contiguous(stride) {
@@ -80,11 +110,11 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
}
}
fn take<T: Copy>(
fn take_impl1<T: Copy>(
vs: &[T],
ids: &[u32],
shape: &Shape,
stride: &[usize],
vs: &[T],
vocab_size: usize,
hidden_size: usize,
) -> Result<Vec<T>> {
@@ -132,6 +162,8 @@ impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
Self::U32(_) => DType::U32,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
}
@@ -148,40 +180,104 @@ impl CpuStorage {
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
// TODO: find a way around the quadratic number of cases below.
match (self, dtype) {
(Self::U32(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::BF16(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::BF16(data))
}
(Self::F16(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32()));
Ok(Self::BF16(data))
}
(Self::F32(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, bf16::from_f32);
Ok(Self::BF16(data))
}
(Self::F64(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, bf16::from_f64);
Ok(Self::BF16(data))
}
(Self::U32(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::BF16(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
}
(Self::F16(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::F16(data))
}
(Self::F32(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, f16::from_f32);
Ok(Self::F16(data))
}
(Self::F64(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, f16::from_f64);
Ok(Self::F16(data))
}
(Self::U32(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v as f32);
let data = unary_map(storage, shape, stride, |v| v as f32);
Ok(Self::F32(data))
}
(Self::BF16(storage), DType::F32) => {
let data = unary_map(storage, shape, stride, |v| v.to_f32());
Ok(Self::F32(data))
}
(Self::F16(storage), DType::F32) => {
let data = unary_map(storage, shape, stride, |v| v.to_f32());
Ok(Self::F32(data))
}
(Self::F32(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v);
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::F32(data))
}
(Self::F64(storage), DType::F32) => {
let data = unary_map(shape, stride, storage, |v| v as f32);
let data = unary_map(storage, shape, stride, |v| v as f32);
Ok(Self::F32(data))
}
(Self::U32(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v);
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::U32(data))
}
(Self::BF16(storage), DType::U32) => {
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
Ok(Self::U32(data))
}
(Self::F16(storage), DType::U32) => {
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
Ok(Self::U32(data))
}
(Self::F32(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v as u32);
let data = unary_map(storage, shape, stride, |v| v as u32);
Ok(Self::U32(data))
}
(Self::F64(storage), DType::U32) => {
let data = unary_map(shape, stride, storage, |v| v as u32);
let data = unary_map(storage, shape, stride, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U32(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v as f64);
let data = unary_map(storage, shape, stride, |v| v as f64);
Ok(Self::F64(data))
}
(Self::BF16(storage), DType::F64) => {
let data = unary_map(storage, shape, stride, |v| v.to_f64());
Ok(Self::F64(data))
}
(Self::F16(storage), DType::F64) => {
let data = unary_map(storage, shape, stride, |v| v.to_f64());
Ok(Self::F64(data))
}
(Self::F32(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v as f64);
let data = unary_map(storage, shape, stride, |v| v as f64);
Ok(Self::F64(data))
}
(Self::F64(storage), DType::F64) => {
let data = unary_map(shape, stride, storage, |v| v);
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::F64(data))
}
}
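All of the `to_dtype` arms above bottom out in the `half` crate's rounding casts, so narrowing is lossy in the expected way. A quick illustration (not part of the diff) of how much each format keeps:

```rust
use half::{bf16, f16};

fn main() {
    let v = 3.14159_f32;
    // f16: 10 mantissa bits -> prints 3.140625
    println!("f16:  {}", f16::from_f32(v).to_f32());
    // bf16: 7 mantissa bits -> prints 3.15625 (coarser, but f32's range)
    println!("bf16: {}", bf16::from_f32(v).to_f32());
}
```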
@@ -214,29 +310,7 @@ impl CpuStorage {
dst_index
};
// TODO: Maybe provide an implementation with higher precision accumulators?
match self {
Self::F32(src) => {
let mut dst = vec![0f32; dst_shape.elem_count()];
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
dst[to_dst_index(unstr_index)] += src[src_index];
}
Ok(Self::F32(dst))
}
Self::F64(src) => {
let mut dst = vec![0f64; dst_shape.elem_count()];
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
dst[to_dst_index(unstr_index)] += src[src_index];
}
Ok(Self::F64(dst))
}
Self::U32(src) => {
let mut dst = vec![0u32; dst_shape.elem_count()];
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
dst[to_dst_index(unstr_index)] += src[src_index];
}
Ok(Self::U32(dst))
}
}
map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index)
}
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
@@ -246,6 +320,42 @@ impl CpuStorage {
let prod_pre_dim = dims[..dim].iter().product();
let prod_post_dim = dims[dim + 1..].iter().product();
match self {
Self::BF16(storage) => {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += storage[idx].to_f64();
idx += prod_post_dim
}
let sum = bf16::from_f64(sum);
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
storage[idx] /= sum;
idx += prod_post_dim
}
}
}
}
Self::F16(storage) => {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += storage[idx].to_f64();
idx += prod_post_dim
}
let sum = f16::from_f64(sum);
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
storage[idx] /= sum;
idx += prod_post_dim
}
}
}
}
Self::F32(storage) => {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
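A detail worth noting in the `BF16`/`F16` arms above: the sum is accumulated in `f64` and rounded to the half type only once at the end. Summing directly in half precision stalls once the running total outgrows the format's granularity, as this small demonstration (not from the diff) shows:

```rust
use half::f16;

fn main() {
    let xs = vec![f16::from_f32(0.1); 10_000];
    // Naive f16 accumulator: once the sum reaches 256, adding 0.1 rounds
    // back to the same value and the total stops growing.
    let naive = xs.iter().fold(f16::ZERO, |acc, &x| acc + x);
    // f64 accumulator rounded once at the end, as divide_by_sum_over_dim does.
    let wide = f16::from_f64(xs.iter().map(|x| x.to_f64()).sum::<f64>());
    println!("naive: {}, wide: {}", naive.to_f32(), wide.to_f32());
}
```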
@@ -297,17 +407,29 @@ impl CpuStorage {
Self::U32(storage) => {
let mul = mul as u32;
let add = add as u32;
let data = unary_map(shape, stride, storage, |v| v * mul + add);
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::U32(data))
}
Self::BF16(storage) => {
let mul = bf16::from_f64(mul);
let add = bf16::from_f64(add);
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::BF16(data))
}
Self::F16(storage) => {
let mul = f16::from_f64(mul);
let add = f16::from_f64(add);
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::F16(data))
}
Self::F32(storage) => {
let mul = mul as f32;
let add = add as f32;
let data = unary_map(shape, stride, storage, |v| v * mul + add);
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::F32(data))
}
Self::F64(storage) => {
let data = unary_map(shape, stride, storage, |v| v * mul + add);
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::F64(data))
}
}
@@ -315,16 +437,24 @@ impl CpuStorage {
pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
match self {
Self::BF16(storage) => {
let data = unary_map(storage, shape, stride, B::bf16);
Ok(Self::BF16(data))
}
Self::F16(storage) => {
let data = unary_map(storage, shape, stride, B::f16);
Ok(Self::F16(data))
}
Self::F32(storage) => {
let data = unary_map(shape, stride, storage, B::f32);
let data = unary_map(storage, shape, stride, B::f32);
Ok(Self::F32(data))
}
Self::F64(storage) => {
let data = unary_map(shape, stride, storage, B::f64);
let data = unary_map(storage, shape, stride, B::f64);
Ok(Self::F64(data))
}
Self::U32(storage) => {
let data = unary_map(shape, stride, storage, B::u32);
let data = unary_map(storage, shape, stride, B::u32);
Ok(Self::U32(data))
}
}
@@ -338,6 +468,14 @@ impl CpuStorage {
rhs_stride: &[usize],
) -> Result<Self> {
match (self, rhs) {
(Self::BF16(lhs), Self::BF16(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16);
Ok(Self::BF16(data))
}
(Self::F16(lhs), Self::F16(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16);
Ok(Self::F16(data))
}
(Self::F32(lhs), Self::F32(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
Ok(Self::F32(data))
@@ -376,6 +514,12 @@ impl CpuStorage {
(Self::U32(src), Self::U32(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
(Self::BF16(src), Self::BF16(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
(Self::F16(src), Self::F16(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
(Self::F32(src), Self::F32(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
@@ -406,6 +550,14 @@ impl CpuStorage {
// TODO: Support types that could be cast to a boolean.
let pred = self.as_slice::<u32>()?;
match (t, f) {
(Self::BF16(t), Self::BF16(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::BF16(data))
}
(Self::F16(t), Self::F16(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F16(data))
}
(Self::F32(t), Self::F32(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F32(data))
@@ -435,20 +587,7 @@ impl CpuStorage {
vocab_size: usize,
) -> Result<Self> {
let ids = self.as_slice::<u32>()?;
match vs {
CpuStorage::F32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F32(storage))
}
CpuStorage::F64(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F64(storage))
}
CpuStorage::U32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::U32(storage))
}
}
map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
}
pub(crate) fn matmul_impl(
@@ -479,63 +618,176 @@ impl CpuStorage {
}
}
let mut dst = vec![0.0; b * m * n];
let dst_shape: Shape = (m, n).into();
let dst_strides = dst_shape.stride_contiguous();
let dst_rs = dst_strides[0];
let dst_cs = dst_strides[1];
for step in 0..b {
let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
gemm(
// m: usize,
m,
// n: usize,
n,
// k: usize,
k,
// dst: *mut T,
dst_p.as_mut_ptr(),
// dst_cs: isize,
dst_cs as isize,
// dst_rs: isize,
dst_rs as isize,
// read_dst: bool,
false,
// lhs: *const T,
lhs_p.as_ptr(),
// lhs_cs: isize,
lhs_cs as isize,
// lhs_rs: isize,
lhs_rs as isize,
// rhs: *const T,
rhs_p.as_ptr(),
// rhs_cs: isize,
rhs_cs as isize,
// rhs_rs: isize,
rhs_rs as isize,
// alpha: T,
1.0,
// beta: T,
1.0,
// conj_dst: bool,
false,
// conj_lhs: bool,
false,
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
)
match (self, rhs) {
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
let mut dst = vec![f16::ZERO; b * m * n];
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
gemm(
// m: usize,
m,
// n: usize,
n,
// k: usize,
k,
// dst: *mut T,
dst_p.as_mut_ptr(),
// dst_cs: isize,
dst_cs as isize,
// dst_rs: isize,
dst_rs as isize,
// read_dst: bool,
false,
// lhs: *const T,
lhs_p.as_ptr(),
// lhs_cs: isize,
lhs_cs as isize,
// lhs_rs: isize,
lhs_rs as isize,
// rhs: *const T,
rhs_p.as_ptr(),
// rhs_cs: isize,
rhs_cs as isize,
// rhs_rs: isize,
rhs_rs as isize,
// alpha: T,
f16::ONE,
// beta: T,
f16::ONE,
// conj_dst: bool,
false,
// conj_lhs: bool,
false,
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
)
}
}
Ok(Self::F16(dst))
}
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
let mut dst = vec![0f32; b * m * n];
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
gemm(
// m: usize,
m,
// n: usize,
n,
// k: usize,
k,
// dst: *mut T,
dst_p.as_mut_ptr(),
// dst_cs: isize,
dst_cs as isize,
// dst_rs: isize,
dst_rs as isize,
// read_dst: bool,
false,
// lhs: *const T,
lhs_p.as_ptr(),
// lhs_cs: isize,
lhs_cs as isize,
// lhs_rs: isize,
lhs_rs as isize,
// rhs: *const T,
rhs_p.as_ptr(),
// rhs_cs: isize,
rhs_cs as isize,
// rhs_rs: isize,
rhs_rs as isize,
// alpha: T,
1f32,
// beta: T,
1f32,
// conj_dst: bool,
false,
// conj_lhs: bool,
false,
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
)
}
}
Ok(Self::F32(dst))
}
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
let mut dst = vec![0f64; b * m * n];
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
gemm(
// m: usize,
m,
// n: usize,
n,
// k: usize,
k,
// dst: *mut T,
dst_p.as_mut_ptr(),
// dst_cs: isize,
dst_cs as isize,
// dst_rs: isize,
dst_rs as isize,
// read_dst: bool,
false,
// lhs: *const T,
lhs_p.as_ptr(),
// lhs_cs: isize,
lhs_cs as isize,
// lhs_rs: isize,
lhs_rs as isize,
// rhs: *const T,
rhs_p.as_ptr(),
// rhs_cs: isize,
rhs_cs as isize,
// rhs_rs: isize,
rhs_rs as isize,
// alpha: T,
1f64,
// beta: T,
1f64,
// conj_dst: bool,
false,
// conj_lhs: bool,
false,
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
)
}
}
Ok(Self::F64(dst))
}
_ => {
// This should be covered by the dtype check above.
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: rhs.dtype(),
op: "matmul",
})
}
}
let c = Self::F32(dst);
Ok(c)
}
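As a reading aid for the `gemm` calls above: with `read_dst` set to `false` the destination is overwritten with `alpha * lhs @ rhs` (the `beta` accumulation term is skipped, and the `conj_*` flags are no-ops for real types). A naive strided reference of the same computation, handy for sanity checks (illustrative sketch; `matmul_ref` is hypothetical):

```rust
// Hypothetical reference implementation of one strided (m, n, k) matmul.
fn matmul_ref(
    (m, n, k): (usize, usize, usize),
    lhs: &[f32], lhs_rs: usize, lhs_cs: usize,
    rhs: &[f32], rhs_rs: usize, rhs_cs: usize,
    dst: &mut [f32], dst_rs: usize, dst_cs: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for l in 0..k {
                acc += lhs[i * lhs_rs + l * lhs_cs] * rhs[l * rhs_rs + j * rhs_cs];
            }
            dst[i * dst_rs + j * dst_cs] = acc; // read_dst = false: overwrite
        }
    }
}

fn main() {
    // Row-major 2x2: identity times [[1, 2], [3, 4]].
    let lhs = [1f32, 0., 0., 1.];
    let rhs = [1f32, 2., 3., 4.];
    let mut dst = [0f32; 4];
    matmul_ref((2, 2, 2), &lhs, 2, 1, &rhs, 2, 1, &mut dst, 2, 1);
    assert_eq!(dst, rhs);
}
```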
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
@@ -545,6 +797,14 @@
let data = vec![1u32; elem_count];
Self::U32(data)
}
DType::BF16 => {
let data = vec![bf16::ONE; elem_count];
Self::BF16(data)
}
DType::F16 => {
let data = vec![f16::ONE; elem_count];
Self::F16(data)
}
DType::F32 => {
let data = vec![1f32; elem_count];
Self::F32(data)
@@ -563,6 +823,14 @@
let data = vec![0u32; elem_count];
Self::U32(data)
}
DType::BF16 => {
let data = vec![bf16::ZERO; elem_count];
Self::BF16(data)
}
DType::F16 => {
let data = vec![f16::ZERO; elem_count];
Self::F16(data)
}
DType::F32 => {
let data = vec![0f32; elem_count];
Self::F32(data)


@@ -2,6 +2,7 @@ use crate::{CpuStorage, DType, Shape};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::Arc;
/// cudarc related errors
@@ -38,6 +39,12 @@ pub enum CudaError {
expected: DType,
got: DType,
},
#[error("{cuda} when loading {module_name}")]
Load {
cuda: cudarc::driver::DriverError,
module_name: &'static str,
},
}
type Result<T> = std::result::Result<T, CudaError>;
@@ -97,6 +104,14 @@ impl CudaDevice {
let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
@@ -124,6 +139,22 @@ impl CudaDevice {
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }?;
@@ -157,6 +188,14 @@ impl CudaDevice {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::F32(data)
@@ -178,7 +217,8 @@ impl CudaDevice {
ptx: &'static str,
) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
self.load_ptx(ptx.into(), module_name, &[module_name])?;
self.load_ptx(ptx.into(), module_name, &[module_name])
.map_err(|cuda| CudaError::Load { cuda, module_name })?;
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
@@ -190,6 +230,8 @@ impl CudaDevice {
#[derive(Debug)]
enum CudaStorageSlice {
U32(CudaSlice<u32>),
BF16(CudaSlice<bf16>),
F16(CudaSlice<f16>),
F32(CudaSlice<f32>),
F64(CudaSlice<f64>),
}
@@ -265,6 +307,8 @@ impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
let slice = match &self.slice {
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
};
@@ -275,6 +319,8 @@ impl CudaStorage {
pub fn dtype(&self) -> DType {
match self.slice {
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::BF16(_) => DType::BF16,
CudaStorageSlice::F16(_) => DType::F16,
CudaStorageSlice::F32(_) => DType::F32,
CudaStorageSlice::F64(_) => DType::F64,
}
@@ -310,6 +356,40 @@ impl CudaStorage {
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
let params = (
el_count,
dims.len(),
&ds,
arg,
&out,
bf16::from_f64(mul),
bf16::from_f64(add),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
let params = (
el_count,
dims.len(),
&ds,
arg,
&out,
f16::from_f64(mul),
f16::from_f64(add),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
@@ -361,6 +441,22 @@ impl CudaStorage {
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<bf16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f32>(dst_el)?;
@@ -402,6 +498,24 @@ impl CudaStorage {
CudaStorageSlice::U32(_arg) => {
todo!("No unary kernels for u32");
}
CudaStorageSlice::BF16(arg) => {
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
@@ -438,6 +552,24 @@ impl CudaStorage {
let dev = self.device();
let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
let slice = match (&self.slice, &rhs.slice) {
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
@@ -479,6 +611,16 @@ impl CudaStorage {
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::BF16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::BF16(cpu_storage))
}
CudaStorageSlice::F16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::F16(cpu_storage))
}
CudaStorageSlice::F32(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
@@ -515,6 +657,24 @@ impl CudaStorage {
let dev = self.device();
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
let slice = match (&t.slice, &f.slice) {
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
@@ -573,7 +733,7 @@ impl CudaStorage {
let slice = match &rhs.slice {
// The kernels below assume that rhs is contiguous.
CudaStorageSlice::U32(arg) => {
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
@@ -581,6 +741,24 @@ impl CudaStorage {
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
@@ -614,6 +792,19 @@ impl CudaStorage {
let elem_count = b * m * n;
let dev = &self.device;
let slice = match (&self.slice, &rhs.slice) {
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
todo!("bf16")
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
@@ -657,6 +848,32 @@ impl CudaStorage {
let dev = &self.device;
let ds = dev.htod_copy([dims, src_stride].concat())?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
@@ -670,6 +887,19 @@ impl CudaStorage {
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
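Two conventions make the dispatch above mechanical: every kernel is registered under an `{op}_{dtype}` name (`affine_bf16`, `sum_f16`, `ucopy_u32`, ...), and on the device side the `_f16` variants only exist when compiled for `__CUDA_ARCH__ >= 530`, since `__half` arithmetic requires compute capability 5.3 or newer. A tiny sketch of the naming scheme (the `kernel_name` helper and the local `DType` enum are hypothetical):

```rust
// Hypothetical helper mirroring the {op}_{dtype} kernel naming used above.
#[derive(Clone, Copy)]
enum DType { U32, BF16, F16, F32, F64 } // local stand-in for crate::DType

fn kernel_name(op: &str, dtype: DType) -> String {
    let suffix = match dtype {
        DType::U32 => "u32",
        DType::BF16 => "bf16",
        DType::F16 => "f16",
        DType::F32 => "f32",
        DType::F64 => "f64",
    };
    format!("{op}_{suffix}")
}

fn main() {
    assert_eq!(kernel_name("affine", DType::BF16), "affine_bf16");
    assert_eq!(kernel_name("ucopy", DType::F16), "ucopy_f16");
}
```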


@@ -3,6 +3,8 @@ use crate::{CpuStorage, Error, Result};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
U32,
BF16,
F16,
F32,
F64,
}
@@ -11,6 +13,8 @@ impl DType {
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U32 => 4,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
Self::F64 => 8,
}
@@ -76,5 +80,7 @@ macro_rules! with_dtype {
};
}
with_dtype!(u32, U32);
with_dtype!(half::f16, F16);
with_dtype!(half::bf16, BF16);
with_dtype!(f32, F32);
with_dtype!(f64, F64);
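Both half formats occupy two bytes, as the new `size_in_bytes` arms state; they differ in how those 16 bits split between exponent and mantissa (bf16: 1/8/7, f16: 1/5/10). A quick look at the encodings (illustrative, not part of the commit):

```rust
use half::{bf16, f16};

fn main() {
    assert_eq!(std::mem::size_of::<f16>(), 2);
    assert_eq!(std::mem::size_of::<bf16>(), 2);
    // 1.0 under both encodings: sign | exponent | mantissa
    println!("{:#018b}", f16::from_f32(1.0).to_bits()); // 0b0011110000000000
    println!("{:#018b}", bf16::from_f32(1.0).to_bits()); // 0b0011111110000000
}
```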


@@ -27,6 +27,7 @@
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Write};
@@ -80,6 +81,8 @@ impl Header {
.collect::<Vec<_>>()
.join(",");
let descr = match self.descr {
DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
DType::U32 => "u4",
@@ -152,7 +155,7 @@ impl Header {
// int64, int32, int16, int8,
// uint8, and bool.
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
// "e" | "f2" => DType::F16,
"e" | "f2" => DType::F16,
"f" | "f4" => DType::F32,
"d" | "f8" => DType::F64,
// "i" | "i4" => DType::S32,
@@ -191,9 +194,20 @@ impl Header {
}
impl Tensor {
// TODO: Add the possibility to read directly to a device?
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
let elem_count = shape.elem_count();
match dtype {
DType::BF16 => {
let mut data_t = vec![bf16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F16 => {
let mut data_t = vec![f16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F32 => {
let mut data_t = vec![0f32; elem_count];
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
@@ -289,6 +303,18 @@ impl Tensor {
f.write_all(header.as_bytes())?;
let elem_count = self.elem_count();
match self.dtype() {
DType::BF16 => {
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
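The `reinterpret_cast`/`reinterpret_cast_mut` calls above come from `half::slice::HalfFloatSliceExt` and simply view a half-precision slice as its raw `u16` bit patterns, which is what the `.npy` payload stores (with `byteorder` handling endianness at the I/O boundary). A small round-trip sketch, not from the diff:

```rust
use half::{f16, slice::HalfFloatSliceExt};

fn main() {
    let xs = vec![f16::from_f32(1.5), f16::from_f32(-2.0)];
    // Zero-copy view of the f16s as their u16 bit patterns.
    let bits: &[u16] = xs.reinterpret_cast();
    // Rebuilding from the bits restores the exact same values.
    let ys: Vec<f16> = bits.iter().map(|&b| f16::from_bits(b)).collect();
    assert_eq!(xs, ys);
}
```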

src/op.rs

@@ -1,4 +1,6 @@
use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;
#[derive(Clone)]
pub(crate) enum Op {
@@ -40,10 +42,13 @@ pub(crate) trait UnaryOp {
pub(crate) trait UnaryOp {
const NAME: &'static str;
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
// contiguous case separately as it's easy to optimize things out there.
const KERNEL_BF16: &'static str;
const KERNEL_F16: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
fn bf16(v1: bf16) -> bf16;
fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64;
fn u32(v1: u32) -> u32;
@@ -51,11 +56,13 @@ pub(crate) trait BinaryOp {
pub(crate) trait BinaryOp {
const NAME: &'static str;
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
// contiguous case separately as it's easy to optimize things out there.
const KERNEL_BF16: &'static str;
const KERNEL_F16: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
fn bf16(v1: bf16, v2: bf16) -> bf16;
fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
fn u32(v1: u32, v2: u32) -> u32;
@@ -75,215 +82,116 @@ pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
impl BinaryOp for Add {
const NAME: &'static str = "add";
const KERNEL_F32: &'static str = "badd_f32";
const KERNEL_F64: &'static str = "badd_f64";
const KERNEL_U32: &'static str = "badd_u32";
fn f32(v1: f32, v2: f32) -> f32 {
v1 + v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 + v2
}
fn u32(v1: u32, v2: u32) -> u32 {
v1 + v2
}
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr) => {
impl BinaryOp for $op {
const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
fn bf16(v1: bf16, v2: bf16) -> bf16 {
$e(v1, v2)
}
fn f16(v1: f16, v2: f16) -> f16 {
$e(v1, v2)
}
fn f32(v1: f32, v2: f32) -> f32 {
$e(v1, v2)
}
fn f64(v1: f64, v2: f64) -> f64 {
$e(v1, v2)
}
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
}
};
}
impl BinaryOp for Sub {
const NAME: &'static str = "sub";
const KERNEL_F32: &'static str = "bsub_f32";
const KERNEL_F64: &'static str = "bsub_f64";
const KERNEL_U32: &'static str = "bsub_u32";
fn f32(v1: f32, v2: f32) -> f32 {
v1 - v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 - v2
}
fn u32(v1: u32, v2: u32) -> u32 {
v1 - v2
}
bin_op!(Add, "add", |v1, v2| v1 + v2);
bin_op!(Sub, "sub", |v1, v2| v1 - v2);
bin_op!(Mul, "mul", |v1, v2| v1 * v2);
bin_op!(Div, "div", |v1, v2| v1 / v2);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOp for $op {
const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
fn bf16($a: bf16) -> bf16 {
$e
}
fn f16($a: f16) -> f16 {
$e
}
fn f32($a: f32) -> f32 {
$e
}
fn f64($a: f64) -> f64 {
$e
}
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
}
};
}
impl BinaryOp for Mul {
const NAME: &'static str = "mul";
const KERNEL_F32: &'static str = "bmul_f32";
const KERNEL_F64: &'static str = "bmul_f64";
const KERNEL_U32: &'static str = "bmul_u32";
fn f32(v1: f32, v2: f32) -> f32 {
v1 * v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 * v2
}
fn u32(v1: u32, v2: u32) -> u32 {
v1 * v2
}
}
impl BinaryOp for Div {
const NAME: &'static str = "div";
const KERNEL_F32: &'static str = "bdiv_f32";
const KERNEL_F64: &'static str = "bdiv_f64";
const KERNEL_U32: &'static str = "bdiv_u32";
fn f32(v1: f32, v2: f32) -> f32 {
v1 / v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 / v2
}
fn u32(v1: u32, v2: u32) -> u32 {
v1 / v2
}
}
impl UnaryOp for Exp {
const NAME: &'static str = "exp";
fn f32(v1: f32) -> f32 {
v1.exp()
}
fn f64(v1: f64) -> f64 {
v1.exp()
}
fn u32(v1: u32) -> u32 {
(v1 as f64).exp() as u32
}
const KERNEL_F32: &'static str = "uexp_f32";
const KERNEL_F64: &'static str = "uexp_f64";
}
impl UnaryOp for Log {
const NAME: &'static str = "log";
fn f32(v1: f32) -> f32 {
v1.ln()
}
fn f64(v1: f64) -> f64 {
v1.ln()
}
fn u32(v1: u32) -> u32 {
(v1 as f64).ln() as u32
}
const KERNEL_F32: &'static str = "ulog_f32";
const KERNEL_F64: &'static str = "ulog_f64";
}
impl UnaryOp for Sin {
const NAME: &'static str = "sin";
fn f32(v1: f32) -> f32 {
v1.sin()
}
fn f64(v1: f64) -> f64 {
v1.sin()
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "usin_f32";
const KERNEL_F64: &'static str = "usin_f64";
}
impl UnaryOp for Cos {
const NAME: &'static str = "cos";
fn f32(v1: f32) -> f32 {
v1.cos()
}
fn f64(v1: f64) -> f64 {
v1.cos()
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "ucos_f32";
const KERNEL_F64: &'static str = "ucos_f64";
}
impl UnaryOp for Abs {
const NAME: &'static str = "abs";
fn f32(v1: f32) -> f32 {
v1.abs()
}
fn f64(v1: f64) -> f64 {
v1.abs()
}
fn u32(v1: u32) -> u32 {
v1
}
const KERNEL_F32: &'static str = "uabs_f32";
const KERNEL_F64: &'static str = "uabs_f64";
}
impl UnaryOp for Neg {
const NAME: &'static str = "neg";
fn f32(v1: f32) -> f32 {
-v1
}
fn f64(v1: f64) -> f64 {
-v1
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "uneg_f32";
const KERNEL_F64: &'static str = "uneg_f64";
}
impl UnaryOp for Sqr {
const NAME: &'static str = "sqr";
fn f32(v1: f32) -> f32 {
v1 * v1
}
fn f64(v1: f64) -> f64 {
v1 * v1
}
fn u32(v: u32) -> u32 {
v * v
}
const KERNEL_F32: &'static str = "usqr_f32";
const KERNEL_F64: &'static str = "usqr_f64";
}
impl UnaryOp for Sqrt {
const NAME: &'static str = "sqrt";
fn f32(v1: f32) -> f32 {
v1.sqrt()
}
fn f64(v1: f64) -> f64 {
v1.sqrt()
}
fn u32(v: u32) -> u32 {
(v as f64).sqrt() as u32
}
const KERNEL_F32: &'static str = "usqrt_f32";
const KERNEL_F64: &'static str = "usqrt_f64";
}
unary_op!(Exp, "exp", v, v.exp());
unary_op!(Log, "log", v, v.ln());
unary_op!(Sin, "sin", v, v.sin());
unary_op!(Cos, "cos", v, v.cos());
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Sqr, "sqr", v, v * v);
unary_op!(Sqrt, "sqrt", v, v.sqrt());
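The single closure body in `unary_op!` works for all four float types because `half` (with its `num-traits` feature, enabled in Cargo.toml above) implements `num_traits::Float`, which supplies `exp`, `ln`, `sin`, `cos`, `sqrt`, and `abs` for `f16`/`bf16`; that is what the new `use num_traits::float::Float;` import is for. The same idea in a generic function (illustrative; `softplus` is not an op in this PR):

```rust
use half::{bf16, f16};
use num_traits::Float;

// One generic body instead of one hand-written impl per dtype.
fn softplus<T: Float>(v: T) -> T {
    (T::one() + v.exp()).ln()
}

fn main() {
    println!("{}", softplus(1.0_f32));
    println!("{}", softplus(f16::ONE));
    println!("{}", softplus(bf16::ONE));
}
```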
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f32(v: f32) -> f32 {
0.5 * v
* (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f64(v: f64) -> f64 {
0.5 * v
* (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
impl UnaryOp for Gelu {
const NAME: &'static str = "gelu";
fn f32(v1: f32) -> f32 {
gelu_f32(v1)
fn bf16(v: bf16) -> bf16 {
bf16::from_f32_const(0.5)
* v
* (bf16::ONE
+ bf16::tanh(
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
}
fn f64(v1: f64) -> f64 {
gelu_f64(v1)
fn f16(v: f16) -> f16 {
f16::from_f32_const(0.5)
* v
* (f16::ONE
+ f16::tanh(
(f16::from_f32_const(2.0) / f16::PI).sqrt()
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
fn u32(v1: u32) -> u32 {
gelu_f64(v1 as f64) as u32
fn f32(v: f32) -> f32 {
0.5 * v
* (1.0
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
fn f64(v: f64) -> f64 {
0.5 * v
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_BF16: &'static str = "gelu_bf16";
const KERNEL_F16: &'static str = "gelu_f16";
const KERNEL_F32: &'static str = "gelu_f32";
const KERNEL_F64: &'static str = "gelu_f64";
const KERNEL_U32: &'static str = "gelu_u32";
}
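For reference, the float bodies of `Gelu` above all implement the usual tanh approximation; note that `v * (1 + 0.044715 * v * v)` equals `v + 0.044715 v^3`, so this matches the more common way of writing it:

$$\mathrm{gelu}(v) \approx \tfrac{1}{2}\, v \left(1 + \tanh\left(\sqrt{2/\pi}\,\left(v + 0.044715\, v^{3}\right)\right)\right)$$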