mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add a couple operators.
This commit is contained in:
@ -95,7 +95,10 @@ trait BinaryOp {
|
||||
}
|
||||
|
||||
struct Add;
|
||||
struct Div;
|
||||
struct Mul;
|
||||
struct Sub;
|
||||
struct Neg;
|
||||
struct Sqr;
|
||||
struct Sqrt;
|
||||
|
||||
@ -109,6 +112,16 @@ impl BinaryOp for Add {
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOp for Sub {
|
||||
const NAME: &'static str = "sub";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 - v2
|
||||
}
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 - v2
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOp for Mul {
|
||||
const NAME: &'static str = "mul";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
@ -119,6 +132,26 @@ impl BinaryOp for Mul {
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOp for Div {
|
||||
const NAME: &'static str = "div";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 / v2
|
||||
}
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 / v2
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Neg {
|
||||
const NAME: &'static str = "neg";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
-v1
|
||||
}
|
||||
fn f64(v1: f64) -> f64 {
|
||||
-v1
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Sqr {
|
||||
const NAME: &'static str = "sqr";
|
||||
fn f32(v1: f32) -> f32 {
|
||||
@ -272,6 +305,16 @@ impl Storage {
|
||||
self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
|
||||
}
|
||||
|
||||
pub(crate) fn sub_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
self.binary_impl::<Sub>(rhs, shape, lhs_stride, rhs_stride)
|
||||
}
|
||||
|
||||
pub(crate) fn mul_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
@ -282,6 +325,20 @@ impl Storage {
|
||||
self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
|
||||
}
|
||||
|
||||
pub(crate) fn div_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
self.binary_impl::<Div>(rhs, shape, lhs_stride, rhs_stride)
|
||||
}
|
||||
|
||||
pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||
self.unary_impl::<Neg>(shape, stride)
|
||||
}
|
||||
|
||||
pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||
self.unary_impl::<Sqr>(shape, stride)
|
||||
}
|
||||
|
Reference in New Issue
Block a user