Add a couple unitary ops.

This commit is contained in:
laurent
2023-06-23 20:19:20 +01:00
parent fe75a01188
commit 8ed350dc94
3 changed files with 127 additions and 0 deletions

View File

@ -21,6 +21,11 @@ pub(crate) enum Op {
mul: f64,
add: f64,
},
Exp(Tensor),
Log(Tensor),
Sin(Tensor),
Cos(Tensor),
Abs(Tensor),
Neg(Tensor),
Reshape(Tensor),
#[allow(dead_code)]
@ -60,6 +65,11 @@ pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
@ -129,6 +139,81 @@ impl BinaryOp for Div {
}
}
impl UnaryOp for Exp {
const NAME: &'static str = "exp";
fn f32(v1: f32) -> f32 {
v1.exp()
}
fn f64(v1: f64) -> f64 {
v1.exp()
}
fn u32(v1: u32) -> u32 {
(v1 as f64).exp() as u32
}
const KERNEL_F32: &'static str = "uexp_f32";
const KERNEL_F64: &'static str = "uexp_f64";
}
impl UnaryOp for Log {
const NAME: &'static str = "log";
fn f32(v1: f32) -> f32 {
v1.ln()
}
fn f64(v1: f64) -> f64 {
v1.ln()
}
fn u32(v1: u32) -> u32 {
(v1 as f64).ln() as u32
}
const KERNEL_F32: &'static str = "ulog_f32";
const KERNEL_F64: &'static str = "ulog_f64";
}
impl UnaryOp for Sin {
const NAME: &'static str = "sin";
fn f32(v1: f32) -> f32 {
v1.sin()
}
fn f64(v1: f64) -> f64 {
v1.sin()
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "usin_f32";
const KERNEL_F64: &'static str = "usin_f64";
}
impl UnaryOp for Cos {
const NAME: &'static str = "cos";
fn f32(v1: f32) -> f32 {
v1.cos()
}
fn f64(v1: f64) -> f64 {
v1.cos()
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "ucos_f32";
const KERNEL_F64: &'static str = "ucos_f64";
}
impl UnaryOp for Abs {
const NAME: &'static str = "abs";
fn f32(v1: f32) -> f32 {
v1.abs()
}
fn f64(v1: f64) -> f64 {
v1.abs()
}
fn u32(v1: u32) -> u32 {
v1
}
const KERNEL_F32: &'static str = "uabs_f32";
const KERNEL_F64: &'static str = "uabs_f64";
}
impl UnaryOp for Neg {
const NAME: &'static str = "neg";
fn f32(v1: f32) -> f32 {