Remove one level of indirection for the binary and unary ops.

This commit is contained in:
laurent
2023-06-22 15:20:51 +01:00
parent 5276755fb3
commit 836ad5f76c
6 changed files with 142 additions and 189 deletions

View File

@ -43,10 +43,12 @@ impl std::fmt::Debug for Tensor {
}
macro_rules! unary_op {
($fn_name:ident, $op_name:ident, $impl_name:ident) => {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self) -> Result<Self> {
let shape = self.shape();
let storage = self.storage.$impl_name(self.shape(), self.stride())?;
let storage = self
.storage
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
@ -61,12 +63,15 @@ macro_rules! unary_op {
}
macro_rules! binary_op {
($fn_name:ident, $op_name:ident, $impl_name:ident) => {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
let storage =
self.storage
.$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
let storage = self.storage.binary_impl::<crate::op::$op_name>(
&rhs.storage,
shape,
self.stride(),
rhs.stride(),
)?;
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
@ -211,14 +216,14 @@ impl Tensor {
// TODO: Also make an inplace version or a pre-allocated? This could be tricky
// if this can create cycles in the compute graph.
binary_op!(add, Add, add_impl);
binary_op!(mul, Mul, mul_impl);
binary_op!(sub, Sub, sub_impl);
binary_op!(div, Div, div_impl);
binary_op!(add, Add);
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
unary_op!(neg, Neg, neg_impl);
unary_op!(sqr, Sqr, sqr_impl);
unary_op!(sqrt, Sqrt, sqrt_impl);
unary_op!(neg, Neg);
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
if self.rank() != 0 {
return Err(Error::UnexpectedNumberOfDims {