Optimize the cpu backend for the contiguous cases.

This commit is contained in:
laurent
2023-06-23 18:08:55 +01:00
parent 132859df75
commit 4f9f14a06b
2 changed files with 31 additions and 13 deletions

View File

@ -39,6 +39,7 @@ pub(crate) trait UnaryOp {
const KERNEL_F64: &'static str;
fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64;
fn u32(v1: u32) -> u32;
}
pub(crate) trait BinaryOp {
@ -134,6 +135,9 @@ impl UnaryOp for Neg {
fn f64(v1: f64) -> f64 {
-v1
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "uneg_f32";
const KERNEL_F64: &'static str = "uneg_f64";
}
@ -146,6 +150,9 @@ impl UnaryOp for Sqr {
fn f64(v1: f64) -> f64 {
v1 * v1
}
fn u32(v: u32) -> u32 {
v * v
}
const KERNEL_F32: &'static str = "usqr_f32";
const KERNEL_F64: &'static str = "usqr_f64";
}
@ -158,6 +165,9 @@ impl UnaryOp for Sqrt {
fn f64(v1: f64) -> f64 {
v1.sqrt()
}
fn u32(v: u32) -> u32 {
(v as f64).sqrt() as u32
}
const KERNEL_F32: &'static str = "usqrt_f32";
const KERNEL_F64: &'static str = "usqrt_f64";
}
@ -184,6 +194,9 @@ impl UnaryOp for Gelu {
fn f64(v1: f64) -> f64 {
gelu_f64(v1)
}
fn u32(v1: u32) -> u32 {
gelu_f64(v1 as f64) as u32
}
const KERNEL_F32: &'static str = "gelu_f32";
const KERNEL_F64: &'static str = "gelu_f64";
}