mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add the kernels.
This commit is contained in:
@ -49,6 +49,7 @@ pub(crate) trait UnaryOp {
|
||||
fn f16(v1: f16) -> f16;
|
||||
fn f32(v1: f32) -> f32;
|
||||
fn f64(v1: f64) -> f64;
|
||||
fn u8(v1: u8) -> u8;
|
||||
fn u32(v1: u32) -> u32;
|
||||
}
|
||||
|
||||
@ -60,6 +61,7 @@ pub(crate) trait BinaryOp {
|
||||
fn f16(v1: f16, v2: f16) -> f16;
|
||||
fn f32(v1: f32, v2: f32) -> f32;
|
||||
fn f64(v1: f64, v2: f64) -> f64;
|
||||
fn u8(v1: u8, v2: u8) -> u8;
|
||||
fn u32(v1: u32, v2: u32) -> u32;
|
||||
}
|
||||
|
||||
@ -96,6 +98,9 @@ macro_rules! bin_op {
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
fn u8(v1: u8, v2: u8) -> u8 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
fn u32(v1: u32, v2: u32) -> u32 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
@ -126,6 +131,9 @@ macro_rules! unary_op {
|
||||
fn f64($a: f64) -> f64 {
|
||||
$e
|
||||
}
|
||||
fn u8(_: u8) -> u8 {
|
||||
todo!("no unary function for u8")
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
todo!("no unary function for u32")
|
||||
}
|
||||
@ -177,6 +185,9 @@ impl UnaryOp for Gelu {
|
||||
* (1.0
|
||||
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
@ -199,6 +210,9 @@ impl UnaryOp for Relu {
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.max(0f64)
|
||||
}
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
|
Reference in New Issue
Block a user