mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Start adding f16/bf16 support.
This commit is contained in:
@ -18,11 +18,13 @@ members = [
|
||||
[dependencies]
safetensors = "0.3.1"
thiserror = "1"
# Optional CUDA bindings. The "f16" feature is enabled because the CUDA
# storage allocates `half::f16` / `half::bf16` device slices.
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
candle-kernels = { path = "kernels", optional = true }
gemm = "0.15.4"
zip = { version = "0.6.6", default-features=false }
byteorder = "1.4.3"
# Half-precision float types (f16 / bf16) and the numeric traits used to
# call float methods on them.
half = "2.3.1"
num-traits = "0.2.15"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1"
|
||||
|
@ -1,6 +1,7 @@
|
||||
use crate::op::{BinaryOp, UnaryOp};
|
||||
use crate::{DType, Error, Result, Shape, StridedIndex};
|
||||
use gemm::{gemm, Parallelism};
|
||||
use half::{bf16, f16};
|
||||
|
||||
// TODO: Think about whether we would be better off with a dtype and
|
||||
// a buffer as an owned slice of bytes.
|
||||
@ -9,6 +10,8 @@ use gemm::{gemm, Parallelism};
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CpuStorage {
|
||||
U32(Vec<u32>),
|
||||
BF16(Vec<bf16>),
|
||||
F16(Vec<f16>),
|
||||
F32(Vec<f32>),
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
@ -132,6 +135,8 @@ impl CpuStorage {
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::U32(_) => DType::U32,
|
||||
Self::BF16(_) => DType::BF16,
|
||||
Self::F16(_) => DType::F16,
|
||||
Self::F32(_) => DType::F32,
|
||||
Self::F64(_) => DType::F64,
|
||||
}
|
||||
@ -545,6 +550,14 @@ impl CpuStorage {
|
||||
let data = vec![1u32; elem_count];
|
||||
Self::U32(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = vec![bf16::ONE; elem_count];
|
||||
Self::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = vec![f16::ONE; elem_count];
|
||||
Self::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = vec![1f32; elem_count];
|
||||
Self::F32(data)
|
||||
@ -563,6 +576,14 @@ impl CpuStorage {
|
||||
let data = vec![0u32; elem_count];
|
||||
Self::U32(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = vec![bf16::ZERO; elem_count];
|
||||
Self::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = vec![f16::ZERO; elem_count];
|
||||
Self::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = vec![0f32; elem_count];
|
||||
Self::F32(data)
|
||||
|
@ -2,6 +2,7 @@ use crate::{CpuStorage, DType, Shape};
|
||||
use candle_kernels as kernels;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// cudarc related errors
|
||||
@ -97,6 +98,14 @@ impl CudaDevice {
|
||||
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc_zeros::<bf16>(elem_count)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = self.alloc_zeros::<f16>(elem_count)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
@ -190,6 +199,8 @@ impl CudaDevice {
|
||||
#[derive(Debug)]
|
||||
enum CudaStorageSlice {
|
||||
U32(CudaSlice<u32>),
|
||||
BF16(CudaSlice<bf16>),
|
||||
F16(CudaSlice<f16>),
|
||||
F32(CudaSlice<f32>),
|
||||
F64(CudaSlice<f64>),
|
||||
}
|
||||
@ -265,6 +276,8 @@ impl CudaStorage {
|
||||
pub fn try_clone(&self) -> Result<Self> {
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
|
||||
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
|
||||
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
|
||||
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
||||
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
||||
};
|
||||
@ -275,6 +288,8 @@ impl CudaStorage {
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self.slice {
|
||||
CudaStorageSlice::U32(_) => DType::U32,
|
||||
CudaStorageSlice::BF16(_) => DType::BF16,
|
||||
CudaStorageSlice::F16(_) => DType::F16,
|
||||
CudaStorageSlice::F32(_) => DType::F32,
|
||||
CudaStorageSlice::F64(_) => DType::F64,
|
||||
}
|
||||
|
@ -3,6 +3,8 @@ use crate::{CpuStorage, Error, Result};
|
||||
/// Element type tag for tensor storage.
///
/// The enum is a plain, field-less tag: it is `Copy`, comparable and
/// hashable, so it can be passed by value and used as a map/set key.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit unsigned integer.
    U32,
    /// 16-bit brain float (`half::bf16`).
    BF16,
    /// 16-bit half-precision float (`half::f16`).
    F16,
    /// 32-bit float.
    F32,
    /// 64-bit float.
    F64,
}
|
||||
@ -11,6 +13,8 @@ impl DType {
|
||||
pub fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::U32 => 4,
|
||||
Self::BF16 => 2,
|
||||
Self::F16 => 2,
|
||||
Self::F32 => 4,
|
||||
Self::F64 => 8,
|
||||
}
|
||||
@ -76,5 +80,7 @@ macro_rules! with_dtype {
|
||||
};
|
||||
}
|
||||
with_dtype!(u32, U32);
|
||||
with_dtype!(half::f16, F16);
|
||||
with_dtype!(half::bf16, BF16);
|
||||
with_dtype!(f32, F32);
|
||||
with_dtype!(f64, F64);
|
||||
|
16
src/npy.rs
16
src/npy.rs
@ -80,6 +80,8 @@ impl Header {
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
let descr = match self.descr {
|
||||
DType::BF16 => todo!("bf16"),
|
||||
DType::F16 => "f2",
|
||||
DType::F32 => "f4",
|
||||
DType::F64 => "f8",
|
||||
DType::U32 => "u4",
|
||||
@ -152,7 +154,7 @@ impl Header {
|
||||
// int64, int32, int16, int8,
|
||||
// uint8, and bool.
|
||||
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
|
||||
// "e" | "f2" => DType::F16,
|
||||
"e" | "f2" => DType::F16,
|
||||
"f" | "f4" => DType::F32,
|
||||
"d" | "f8" => DType::F64,
|
||||
// "i" | "i4" => DType::S32,
|
||||
@ -194,6 +196,12 @@ impl Tensor {
|
||||
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::BF16 => {
|
||||
todo!("bf16")
|
||||
}
|
||||
DType::F16 => {
|
||||
todo!("f16")
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data_t = vec![0f32; elem_count];
|
||||
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
|
||||
@ -289,6 +297,12 @@ impl Tensor {
|
||||
f.write_all(header.as_bytes())?;
|
||||
let elem_count = self.elem_count();
|
||||
match self.dtype() {
|
||||
DType::BF16 => {
|
||||
todo!("bf16")
|
||||
}
|
||||
DType::F16 => {
|
||||
todo!("f16")
|
||||
}
|
||||
DType::F32 => {
|
||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
|
||||
|
310
src/op.rs
310
src/op.rs
@ -1,4 +1,6 @@
|
||||
use crate::Tensor;
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum Op {
|
||||
@ -40,10 +42,13 @@ pub(crate) enum Op {
|
||||
|
||||
pub(crate) trait UnaryOp {
|
||||
const NAME: &'static str;
|
||||
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
|
||||
// contiguous case separately as it's easy to optimize things out there.
|
||||
const KERNEL_BF16: &'static str;
|
||||
const KERNEL_F16: &'static str;
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
const KERNEL_U32: &'static str;
|
||||
fn bf16(v1: bf16) -> bf16;
|
||||
fn f16(v1: f16) -> f16;
|
||||
fn f32(v1: f32) -> f32;
|
||||
fn f64(v1: f64) -> f64;
|
||||
fn u32(v1: u32) -> u32;
|
||||
@ -51,11 +56,13 @@ pub(crate) trait UnaryOp {
|
||||
|
||||
pub(crate) trait BinaryOp {
|
||||
const NAME: &'static str;
|
||||
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
|
||||
// contiguous case separately as it's easy to optimize things out there.
|
||||
const KERNEL_BF16: &'static str;
|
||||
const KERNEL_F16: &'static str;
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
const KERNEL_U32: &'static str;
|
||||
fn bf16(v1: bf16, v2: bf16) -> bf16;
|
||||
fn f16(v1: f16, v2: f16) -> f16;
|
||||
fn f32(v1: f32, v2: f32) -> f32;
|
||||
fn f64(v1: f64, v2: f64) -> f64;
|
||||
fn u32(v1: u32, v2: u32) -> u32;
|
||||
@ -75,215 +82,116 @@ pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
|
||||
impl BinaryOp for Add {
|
||||
const NAME: &'static str = "add";
|
||||
const KERNEL_F32: &'static str = "badd_f32";
|
||||
const KERNEL_F64: &'static str = "badd_f64";
|
||||
const KERNEL_U32: &'static str = "badd_u32";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 + v2
|
||||
}
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 + v2
|
||||
}
|
||||
fn u32(v1: u32, v2: u32) -> u32 {
|
||||
v1 + v2
|
||||
}
|
||||
/// Implements [`BinaryOp`] for the marker type `$op`.
///
/// * `$op`   - the marker struct to implement the trait for.
/// * `$name` - string literal used as `NAME` and to derive the kernel names,
///   which follow the pattern `"b" + $name + "_" + dtype`
///   (e.g. `"badd_f32"` for `"add"`).
/// * `$e`    - a two-argument closure expression; it is re-expanded inside
///   each typed method so the same expression is type-checked once per
///   element type (`bf16`, `f16`, `f32`, `f64`, `u32`).
macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr) => {
        impl BinaryOp for $op {
            const NAME: &'static str = $name;
            const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
            const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
            const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
            const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
            const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
            fn bf16(v1: bf16, v2: bf16) -> bf16 {
                $e(v1, v2)
            }
            fn f16(v1: f16, v2: f16) -> f16 {
                $e(v1, v2)
            }
            fn f32(v1: f32, v2: f32) -> f32 {
                $e(v1, v2)
            }
            fn f64(v1: f64, v2: f64) -> f64 {
                $e(v1, v2)
            }
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
        }
    };
}
|
||||
|
||||
impl BinaryOp for Sub {
|
||||
const NAME: &'static str = "sub";
|
||||
const KERNEL_F32: &'static str = "bsub_f32";
|
||||
const KERNEL_F64: &'static str = "bsub_f64";
|
||||
const KERNEL_U32: &'static str = "bsub_u32";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 - v2
|
||||
}
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 - v2
|
||||
}
|
||||
fn u32(v1: u32, v2: u32) -> u32 {
|
||||
v1 - v2
|
||||
}
|
||||
bin_op!(Add, "add", |v1, v2| v1 + v2);
|
||||
bin_op!(Sub, "sub", |v1, v2| v1 - v2);
|
||||
bin_op!(Mul, "mul", |v1, v2| v1 * v2);
|
||||
bin_op!(Div, "div", |v1, v2| v1 / v2);
|
||||
|
||||
/// Implements [`UnaryOp`] for the marker type `$op`.
///
/// * `$op`   - the marker struct to implement the trait for.
/// * `$name` - string literal used as `NAME` and to derive the kernel names,
///   which follow the pattern `"u" + $name + "_" + dtype`
///   (e.g. `"uexp_f32"` for `"exp"`).
/// * `$a`    - identifier that names the input value inside `$e`.
/// * `$e`    - expression computing the result from `$a`; it is expanded
///   once per floating-point element type (`bf16`, `f16`, `f32`, `f64`).
///
/// The generated `u32` method ignores its argument and panics through
/// `todo!`: no integer version of these operations is provided.
macro_rules! unary_op {
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOp for $op {
            const NAME: &'static str = $name;
            const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
            const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
            const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
            const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
            const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            fn f16($a: f16) -> f16 {
                $e
            }
            fn f32($a: f32) -> f32 {
                $e
            }
            fn f64($a: f64) -> f64 {
                $e
            }
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
        }
    };
}
|
||||
|
||||
impl BinaryOp for Mul {
|
||||
const NAME: &'static str = "mul";
|
||||
const KERNEL_F32: &'static str = "bmul_f32";
|
||||
const KERNEL_F64: &'static str = "bmul_f64";
|
||||
const KERNEL_U32: &'static str = "bmul_u32";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 * v2
|
||||
}
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 * v2
|
||||
}
|
||||
fn u32(v1: u32, v2: u32) -> u32 {
|
||||
v1 * v2
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOp for Div {
|
||||
const NAME: &'static str = "div";
|
||||
const KERNEL_F32: &'static str = "bdiv_f32";
|
||||
const KERNEL_F64: &'static str = "bdiv_f64";
|
||||
const KERNEL_U32: &'static str = "bdiv_u32";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 / v2
|
||||
}
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 / v2
|
||||
}
|
||||
fn u32(v1: u32, v2: u32) -> u32 {
|
||||
v1 / v2
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Exp {
|
||||
const NAME: &'static str = "exp";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.exp()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.exp()
|
||||
}
|
||||
fn u32(v1: u32) -> u32 {
|
||||
(v1 as f64).exp() as u32
|
||||
}
|
||||
const KERNEL_F32: &'static str = "uexp_f32";
|
||||
const KERNEL_F64: &'static str = "uexp_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Log {
|
||||
const NAME: &'static str = "log";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.ln()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.ln()
|
||||
}
|
||||
fn u32(v1: u32) -> u32 {
|
||||
(v1 as f64).ln() as u32
|
||||
}
|
||||
const KERNEL_F32: &'static str = "ulog_f32";
|
||||
const KERNEL_F64: &'static str = "ulog_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Sin {
|
||||
const NAME: &'static str = "sin";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.sin()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.sin()
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
const KERNEL_F32: &'static str = "usin_f32";
|
||||
const KERNEL_F64: &'static str = "usin_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Cos {
|
||||
const NAME: &'static str = "cos";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.cos()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.cos()
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
const KERNEL_F32: &'static str = "ucos_f32";
|
||||
const KERNEL_F64: &'static str = "ucos_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Abs {
|
||||
const NAME: &'static str = "abs";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.abs()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.abs()
|
||||
}
|
||||
fn u32(v1: u32) -> u32 {
|
||||
v1
|
||||
}
|
||||
const KERNEL_F32: &'static str = "uabs_f32";
|
||||
const KERNEL_F64: &'static str = "uabs_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Neg {
|
||||
const NAME: &'static str = "neg";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
-v1
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
-v1
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
const KERNEL_F32: &'static str = "uneg_f32";
|
||||
const KERNEL_F64: &'static str = "uneg_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Sqr {
|
||||
const NAME: &'static str = "sqr";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1 * v1
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1 * v1
|
||||
}
|
||||
fn u32(v: u32) -> u32 {
|
||||
v * v
|
||||
}
|
||||
const KERNEL_F32: &'static str = "usqr_f32";
|
||||
const KERNEL_F64: &'static str = "usqr_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Sqrt {
|
||||
const NAME: &'static str = "sqrt";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.sqrt()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.sqrt()
|
||||
}
|
||||
fn u32(v: u32) -> u32 {
|
||||
(v as f64).sqrt() as u32
|
||||
}
|
||||
const KERNEL_F32: &'static str = "usqrt_f32";
|
||||
const KERNEL_F64: &'static str = "usqrt_f64";
|
||||
}
|
||||
unary_op!(Exp, "exp", v, v.exp());
|
||||
unary_op!(Log, "log", v, v.ln());
|
||||
unary_op!(Sin, "sin", v, v.sin());
|
||||
unary_op!(Cos, "cos", v, v.cos());
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Sqr, "sqr", v, v * v);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt());
|
||||
|
||||
/// `gelu` operation, single-precision.
///
/// Evaluates the tanh-based approximation
/// `0.5 * v * (1 + tanh(sqrt(2 / pi) * v * (1 + 0.044715 * v^2)))`.
///
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f32(v: f32) -> f32 {
    0.5 * v
        * (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
|
||||
/// `gelu` operation, double-precision.
///
/// Evaluates the tanh-based approximation
/// `0.5 * v * (1 + tanh(sqrt(2 / pi) * v * (1 + 0.044715 * v^2)))`.
///
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f64(v: f64) -> f64 {
    0.5 * v
        * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
|
||||
impl UnaryOp for Gelu {
|
||||
const NAME: &'static str = "gelu";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
gelu_f32(v1)
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from_f32_const(0.5)
|
||||
* v
|
||||
* (bf16::ONE
|
||||
+ bf16::tanh(
|
||||
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
|
||||
* v
|
||||
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
||||
))
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
gelu_f64(v1)
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from_f32_const(0.5)
|
||||
* v
|
||||
* (f16::ONE
|
||||
+ f16::tanh(
|
||||
(f16::from_f32_const(2.0) / f16::PI).sqrt()
|
||||
* v
|
||||
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
||||
))
|
||||
}
|
||||
fn u32(v1: u32) -> u32 {
|
||||
gelu_f64(v1 as f64) as u32
|
||||
fn f32(v: f32) -> f32 {
|
||||
0.5 * v
|
||||
* (1.0
|
||||
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
fn f64(v: f64) -> f64 {
|
||||
0.5 * v
|
||||
* (1.0
|
||||
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
const KERNEL_BF16: &'static str = "gelu_bf16";
|
||||
const KERNEL_F16: &'static str = "gelu_f16";
|
||||
const KERNEL_F32: &'static str = "gelu_f32";
|
||||
const KERNEL_F64: &'static str = "gelu_f64";
|
||||
const KERNEL_U32: &'static str = "gelu_u32";
|
||||
}
|
||||
|
Reference in New Issue
Block a user