mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Add a couple unitary ops.
This commit is contained in:
85
src/op.rs
85
src/op.rs
@ -21,6 +21,11 @@ pub(crate) enum Op {
|
||||
mul: f64,
|
||||
add: f64,
|
||||
},
|
||||
Exp(Tensor),
|
||||
Log(Tensor),
|
||||
Sin(Tensor),
|
||||
Cos(Tensor),
|
||||
Abs(Tensor),
|
||||
Neg(Tensor),
|
||||
Reshape(Tensor),
|
||||
#[allow(dead_code)]
|
||||
@ -60,6 +65,11 @@ pub(crate) struct Add;
|
||||
pub(crate) struct Div;
|
||||
pub(crate) struct Mul;
|
||||
pub(crate) struct Sub;
|
||||
pub(crate) struct Exp;
|
||||
pub(crate) struct Log;
|
||||
pub(crate) struct Sin;
|
||||
pub(crate) struct Cos;
|
||||
pub(crate) struct Abs;
|
||||
pub(crate) struct Neg;
|
||||
pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
@ -129,6 +139,81 @@ impl BinaryOp for Div {
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Exp {
|
||||
const NAME: &'static str = "exp";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.exp()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.exp()
|
||||
}
|
||||
fn u32(v1: u32) -> u32 {
|
||||
(v1 as f64).exp() as u32
|
||||
}
|
||||
const KERNEL_F32: &'static str = "uexp_f32";
|
||||
const KERNEL_F64: &'static str = "uexp_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Log {
|
||||
const NAME: &'static str = "log";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.ln()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.ln()
|
||||
}
|
||||
fn u32(v1: u32) -> u32 {
|
||||
(v1 as f64).ln() as u32
|
||||
}
|
||||
const KERNEL_F32: &'static str = "ulog_f32";
|
||||
const KERNEL_F64: &'static str = "ulog_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Sin {
|
||||
const NAME: &'static str = "sin";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.sin()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.sin()
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
const KERNEL_F32: &'static str = "usin_f32";
|
||||
const KERNEL_F64: &'static str = "usin_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Cos {
|
||||
const NAME: &'static str = "cos";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.cos()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.cos()
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
const KERNEL_F32: &'static str = "ucos_f32";
|
||||
const KERNEL_F64: &'static str = "ucos_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Abs {
|
||||
const NAME: &'static str = "abs";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
v1.abs()
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.abs()
|
||||
}
|
||||
fn u32(v1: u32) -> u32 {
|
||||
v1
|
||||
}
|
||||
const KERNEL_F32: &'static str = "uabs_f32";
|
||||
const KERNEL_F64: &'static str = "uabs_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Neg {
|
||||
const NAME: &'static str = "neg";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
|
Reference in New Issue
Block a user