Add the kernels.

This commit is contained in:
laurent
2023-06-30 10:26:56 +01:00
parent a7b16cbb98
commit 8ad47907f3
11 changed files with 117 additions and 3 deletions

View File

@ -49,6 +49,7 @@ pub(crate) trait UnaryOp {
fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
}
@ -60,6 +61,7 @@ pub(crate) trait BinaryOp {
fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32;
}
@ -96,6 +98,9 @@ macro_rules! bin_op {
fn f64(v1: f64, v2: f64) -> f64 {
$e(v1, v2)
}
fn u8(v1: u8, v2: u8) -> u8 {
$e(v1, v2)
}
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
@ -126,6 +131,9 @@ macro_rules! unary_op {
fn f64($a: f64) -> f64 {
$e
}
fn u8(_: u8) -> u8 {
todo!("no unary function for u8")
}
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
@ -177,6 +185,9 @@ impl UnaryOp for Gelu {
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
fn u8(_: u8) -> u8 {
0
}
fn u32(_: u32) -> u32 {
0
}
@ -199,6 +210,9 @@ impl UnaryOp for Relu {
fn f64(v: f64) -> f64 {
v.max(0f64)
}
fn u8(v: u8) -> u8 {
v
}
fn u32(v: u32) -> u32 {
v
}