Add a couple operators.

This commit is contained in:
laurent
2023-06-20 22:32:11 +01:00
parent f1f372b13e
commit 78bac0ed32
3 changed files with 93 additions and 9 deletions

View File

@ -95,7 +95,10 @@ trait BinaryOp {
}
struct Add;
struct Div;
struct Mul;
struct Sub;
struct Neg;
struct Sqr;
struct Sqrt;
@ -109,6 +112,16 @@ impl BinaryOp for Add {
}
}
impl BinaryOp for Sub {
const NAME: &'static str = "sub";
fn f32(v1: f32, v2: f32) -> f32 {
v1 - v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 - v2
}
}
impl BinaryOp for Mul {
const NAME: &'static str = "mul";
fn f32(v1: f32, v2: f32) -> f32 {
@ -119,6 +132,26 @@ impl BinaryOp for Mul {
}
}
impl BinaryOp for Div {
const NAME: &'static str = "div";
fn f32(v1: f32, v2: f32) -> f32 {
v1 / v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 / v2
}
}
impl UnaryOp for Neg {
const NAME: &'static str = "neg";
fn f32(v1: f32) -> f32 {
-v1
}
fn f64(v1: f64) -> f64 {
-v1
}
}
impl UnaryOp for Sqr {
const NAME: &'static str = "sqr";
fn f32(v1: f32) -> f32 {
@ -272,6 +305,16 @@ impl Storage {
self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn sub_impl(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.binary_impl::<Sub>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn mul_impl(
&self,
rhs: &Self,
@ -282,6 +325,20 @@ impl Storage {
self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn div_impl(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.binary_impl::<Div>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Neg>(shape, stride)
}
pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqr>(shape, stride)
}