Process unary functions per block (#180)

* Process unary functions per block.

* Add some inline hints.
This commit is contained in:
Laurent Mazare
2023-07-17 10:22:33 +01:00
committed by GitHub
parent 2a74019ec6
commit 28e1c07304
2 changed files with 53 additions and 3 deletions

View File

@ -95,21 +95,27 @@ macro_rules! bin_op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("b", $name);
const V: Self = $op;
// bf16 variant of the binary op: forwards both operands to the callable
// supplied as the macro's `$e` argument and returns its result.
#[inline(always)]
fn bf16(v1: bf16, v2: bf16) -> bf16 {
$e(v1, v2)
}
// f16 variant of the binary op: forwards both operands to the callable
// supplied as the macro's `$e` argument and returns its result.
#[inline(always)]
fn f16(v1: f16, v2: f16) -> f16 {
$e(v1, v2)
}
// f32 variant of the binary op: forwards both operands to the callable
// supplied as the macro's `$e` argument and returns its result.
#[inline(always)]
fn f32(v1: f32, v2: f32) -> f32 {
$e(v1, v2)
}
// f64 variant of the binary op: forwards both operands to the callable
// supplied as the macro's `$e` argument and returns its result.
#[inline(always)]
fn f64(v1: f64, v2: f64) -> f64 {
$e(v1, v2)
}
// u8 variant of the binary op: forwards both operands to the callable
// supplied as the macro's `$e` argument and returns its result.
#[inline(always)]
fn u8(v1: u8, v2: u8) -> u8 {
$e(v1, v2)
}
// u32 variant of the binary op: forwards both operands to the callable
// supplied as the macro's `$e` argument and returns its result.
#[inline(always)]
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
@ -128,21 +134,27 @@ macro_rules! unary_op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
// bf16 variant of the unary op: the argument is bound to the identifier passed
// as `$a`, and the body is the macro's `$e` expression written in terms of it.
#[inline(always)]
fn bf16($a: bf16) -> bf16 {
$e
}
// f16 variant of the unary op: the argument is bound to the identifier passed
// as `$a`, and the body is the macro's `$e` expression written in terms of it.
#[inline(always)]
fn f16($a: f16) -> f16 {
$e
}
// f32 variant of the unary op: the argument is bound to the identifier passed
// as `$a`, and the body is the macro's `$e` expression written in terms of it.
#[inline(always)]
fn f32($a: f32) -> f32 {
$e
}
// f64 variant of the unary op: the argument is bound to the identifier passed
// as `$a`, and the body is the macro's `$e` expression written in terms of it.
#[inline(always)]
fn f64($a: f64) -> f64 {
$e
}
// Macro-generated unary ops have no u8 implementation: the argument is ignored
// and any call panics through `todo!`.
#[inline(always)]
fn u8(_: u8) -> u8 {
todo!("no unary function for u8")
}
// Macro-generated unary ops have no u32 implementation: the argument is ignored
// and any call panics through `todo!`.
#[inline(always)]
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
@ -164,6 +176,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt());
impl UnaryOp for Gelu {
const NAME: &'static str = "gelu";
const V: Self = Gelu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f32_const(0.5)
* v
@ -174,6 +187,7 @@ impl UnaryOp for Gelu {
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f32_const(0.5)
* v
@ -184,19 +198,23 @@ impl UnaryOp for Gelu {
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
/// GELU for `f32`, computed as
/// `0.5 * v * (1 + tanh(sqrt(2 / pi) * v * (1 + 0.044715 * v^2)))`.
#[inline(always)]
fn f32(v: f32) -> f32 {
    // sqrt(2 / pi): scale applied to the argument of tanh.
    let scale = (2.0f32 / std::f32::consts::PI).sqrt();
    let cubic_term = 1.0 + 0.044715 * v * v;
    // Multiplication order is kept as (scale * v) * cubic_term so rounding is unchanged.
    let inner = scale * v * cubic_term;
    0.5 * v * (1.0 + inner.tanh())
}
/// GELU for `f64`, computed as
/// `0.5 * v * (1 + tanh(sqrt(2 / pi) * v * (1 + 0.044715 * v^2)))`.
#[inline(always)]
fn f64(v: f64) -> f64 {
    // sqrt(2 / pi): scale applied to the argument of tanh.
    let scale = (2.0f64 / std::f64::consts::PI).sqrt();
    let cubic_term = 1.0 + 0.044715 * v * v;
    // Multiplication order is kept as (scale * v) * cubic_term so rounding is unchanged.
    let inner = scale * v * cubic_term;
    0.5 * v * (1.0 + inner.tanh())
}
/// GELU for `u8`: the input is ignored and the result is always zero.
#[inline(always)]
fn u8(_unused: u8) -> u8 {
    0u8
}
/// GELU for `u32`: the input is ignored and the result is always zero.
#[inline(always)]
fn u32(_unused: u32) -> u32 {
    0u32
}
@ -207,21 +225,27 @@ impl UnaryOp for Relu {
const NAME: &'static str = "relu";
const KERNEL: &'static str = "urelu";
const V: Self = Relu;
/// ReLU for `bf16`: returns the larger of the input and zero.
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
    let floor = bf16::ZERO;
    v.max(floor)
}
/// ReLU for `f16`: returns the larger of the input and zero.
#[inline(always)]
fn f16(v: f16) -> f16 {
    let floor = f16::ZERO;
    v.max(floor)
}
/// ReLU for `f32`: returns the larger of the input and zero.
#[inline(always)]
fn f32(v: f32) -> f32 {
    f32::max(v, 0.0)
}
/// ReLU for `f64`: returns the larger of the input and zero.
#[inline(always)]
fn f64(v: f64) -> f64 {
    f64::max(v, 0.0)
}
/// ReLU for `u8`: unsigned values are never negative, so the input is
/// returned unchanged.
#[inline(always)]
fn u8(unsigned: u8) -> u8 {
    unsigned
}
/// ReLU for `u32`: unsigned values are never negative, so the input is
/// returned unchanged.
#[inline(always)]
fn u32(unsigned: u32) -> u32 {
    unsigned
}