diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 2c708389..2a393f5f 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -1,4 +1,4 @@
-use crate::storage::{BinaryOp, UnaryOp};
+use crate::op::{BinaryOp, UnaryOp};
 use crate::{DType, Error, Result, Shape, StridedIndex};
 use gemm::{gemm, Parallelism};
 
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index ce0e803d..378db1ce 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -164,7 +164,7 @@ impl CudaStorage {
         }
     }
 
-    pub(crate) fn unary_impl<U: crate::storage::UnaryOp>(
+    pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
         &self,
         shape: &Shape,
         stride: &[usize],
@@ -198,7 +198,7 @@ impl CudaStorage {
         }
     }
 
-    pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
+    pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         rhs: &Self,
         shape: &Shape,
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index d5e0ae63..63b55bac 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -54,15 +54,11 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn unary_impl<U: crate::storage::UnaryOp>(
-        &self,
-        _: &Shape,
-        _: &[usize],
-    ) -> Result<Self> {
+    pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
+    pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         _: &Self,
         _: &Shape,
diff --git a/src/op.rs b/src/op.rs
index 157ce3b3..cbf0789f 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -18,3 +18,115 @@ pub(crate) enum Op {
     Sqrt(Tensor),
     // TODO: Support for custom ops.
 }
+
+pub(crate) trait UnaryOp {
+    const NAME: &'static str;
+    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
+    // contiguous case separately as it's easy to optimize things out there.
+    const KERNEL_F32: &'static str;
+    const KERNEL_F64: &'static str;
+    fn f32(v1: f32) -> f32;
+    fn f64(v1: f64) -> f64;
+}
+
+pub(crate) trait BinaryOp {
+    const NAME: &'static str;
+    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
+    // contiguous case separately as it's easy to optimize things out there.
+    const KERNEL_F32: &'static str;
+    const KERNEL_F64: &'static str;
+    fn f32(v1: f32, v2: f32) -> f32;
+    fn f64(v1: f64, v2: f64) -> f64;
+}
+
+pub(crate) struct Add;
+pub(crate) struct Div;
+pub(crate) struct Mul;
+pub(crate) struct Sub;
+pub(crate) struct Neg;
+pub(crate) struct Sqr;
+pub(crate) struct Sqrt;
+
+impl BinaryOp for Add {
+    const NAME: &'static str = "add";
+    const KERNEL_F32: &'static str = "badd_f32";
+    const KERNEL_F64: &'static str = "badd_f64";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 + v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 + v2
+    }
+}
+
+impl BinaryOp for Sub {
+    const NAME: &'static str = "sub";
+    const KERNEL_F32: &'static str = "bsub_f32";
+    const KERNEL_F64: &'static str = "bsub_f64";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 - v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 - v2
+    }
+}
+
+impl BinaryOp for Mul {
+    const NAME: &'static str = "mul";
+    const KERNEL_F32: &'static str = "bmul_f32";
+    const KERNEL_F64: &'static str = "bmul_f64";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 * v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 * v2
+    }
+}
+
+impl BinaryOp for Div {
+    const NAME: &'static str = "div";
+    const KERNEL_F32: &'static str = "bdiv_f32";
+    const KERNEL_F64: &'static str = "bdiv_f64";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 / v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 / v2
+    }
+}
+
+impl UnaryOp for Neg {
+    const NAME: &'static str = "neg";
+    fn f32(v1: f32) -> f32 {
+        -v1
+    }
+    fn f64(v1: f64) -> f64 {
+        -v1
+    }
+    const KERNEL_F32: &'static str = "uneg_f32";
+    const KERNEL_F64: &'static str = "uneg_f64";
+}
+
+impl UnaryOp for Sqr {
+    const NAME: &'static str = "sqr";
+    fn f32(v1: f32) -> f32 {
+        v1 * v1
+    }
+    fn f64(v1: f64) -> f64 {
+        v1 * v1
+    }
+    const KERNEL_F32: &'static str = "usqr_f32";
+    const KERNEL_F64: &'static str = "usqr_f64";
+}
+
+impl UnaryOp for Sqrt {
+    const NAME: &'static str = "sqrt";
+    fn f32(v1: f32) -> f32 {
+        v1.sqrt()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.sqrt()
+    }
+    const KERNEL_F32: &'static str = "usqrt_f32";
+    const KERNEL_F64: &'static str = "usqrt_f64";
+}
diff --git a/src/storage.rs b/src/storage.rs
index e96f4706..00746089 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
+use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
 
 #[derive(Debug, Clone)]
 pub enum Storage {
@@ -6,118 +6,6 @@ pub enum Storage {
     Cuda(CudaStorage),
 }
 
-pub(crate) trait UnaryOp {
-    const NAME: &'static str;
-    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
-    // contiguous case separately as it's easy to optimize things out there.
-    const KERNEL_F32: &'static str;
-    const KERNEL_F64: &'static str;
-    fn f32(v1: f32) -> f32;
-    fn f64(v1: f64) -> f64;
-}
-
-pub(crate) trait BinaryOp {
-    const NAME: &'static str;
-    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
-    // contiguous case separately as it's easy to optimize things out there.
-    const KERNEL_F32: &'static str;
-    const KERNEL_F64: &'static str;
-    fn f32(v1: f32, v2: f32) -> f32;
-    fn f64(v1: f64, v2: f64) -> f64;
-}
-
-struct Add;
-struct Div;
-struct Mul;
-struct Sub;
-struct Neg;
-struct Sqr;
-struct Sqrt;
-
-impl BinaryOp for Add {
-    const NAME: &'static str = "add";
-    const KERNEL_F32: &'static str = "badd_f32";
-    const KERNEL_F64: &'static str = "badd_f64";
-    fn f32(v1: f32, v2: f32) -> f32 {
-        v1 + v2
-    }
-    fn f64(v1: f64, v2: f64) -> f64 {
-        v1 + v2
-    }
-}
-
-impl BinaryOp for Sub {
-    const NAME: &'static str = "sub";
-    const KERNEL_F32: &'static str = "bsub_f32";
-    const KERNEL_F64: &'static str = "bsub_f64";
-    fn f32(v1: f32, v2: f32) -> f32 {
-        v1 - v2
-    }
-    fn f64(v1: f64, v2: f64) -> f64 {
-        v1 - v2
-    }
-}
-
-impl BinaryOp for Mul {
-    const NAME: &'static str = "mul";
-    const KERNEL_F32: &'static str = "bmul_f32";
-    const KERNEL_F64: &'static str = "bmul_f64";
-    fn f32(v1: f32, v2: f32) -> f32 {
-        v1 * v2
-    }
-    fn f64(v1: f64, v2: f64) -> f64 {
-        v1 * v2
-    }
-}
-
-impl BinaryOp for Div {
-    const NAME: &'static str = "div";
-    const KERNEL_F32: &'static str = "bdiv_f32";
-    const KERNEL_F64: &'static str = "bdiv_f64";
-    fn f32(v1: f32, v2: f32) -> f32 {
-        v1 / v2
-    }
-    fn f64(v1: f64, v2: f64) -> f64 {
-        v1 / v2
-    }
-}
-
-impl UnaryOp for Neg {
-    const NAME: &'static str = "neg";
-    fn f32(v1: f32) -> f32 {
-        -v1
-    }
-    fn f64(v1: f64) -> f64 {
-        -v1
-    }
-    const KERNEL_F32: &'static str = "uneg_f32";
-    const KERNEL_F64: &'static str = "uneg_f64";
-}
-
-impl UnaryOp for Sqr {
-    const NAME: &'static str = "sqr";
-    fn f32(v1: f32) -> f32 {
-        v1 * v1
-    }
-    fn f64(v1: f64) -> f64 {
-        v1 * v1
-    }
-    const KERNEL_F32: &'static str = "usqr_f32";
-    const KERNEL_F64: &'static str = "usqr_f64";
-}
-
-impl UnaryOp for Sqrt {
-    const NAME: &'static str = "sqrt";
-    fn f32(v1: f32) -> f32 {
-        v1.sqrt()
-    }
-    fn f64(v1: f64) -> f64 {
-        v1.sqrt()
-    }
-    const KERNEL_F32: &'static str = "usqrt_f32";
-    const KERNEL_F64: &'static str = "usqrt_f64";
-}
-
 impl Storage {
     pub fn device(&self) -> Device {
         match self {
@@ -173,7 +61,11 @@ impl Storage {
         }
     }
 
-    fn unary_impl<U: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+    pub(crate) fn unary_impl<U: op::UnaryOp>(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+    ) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
             Storage::Cpu(storage) => {
@@ -188,7 +80,7 @@ impl Storage {
         }
     }
     // TODO: Support broadcasting?
-    fn binary_impl<B: BinaryOp>(
+    pub(crate) fn binary_impl<B: op::BinaryOp>(
         &self,
         rhs: &Self,
         shape: &Shape,
@@ -218,58 +110,6 @@ impl Storage {
         }
     }
 
-    pub(crate) fn add_impl(
-        &self,
-        rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
-    ) -> Result<Self> {
-        self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
-    }
-
-    pub(crate) fn sub_impl(
-        &self,
-        rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
-    ) -> Result<Self> {
-        self.binary_impl::<Sub>(rhs, shape, lhs_stride, rhs_stride)
-    }
-
-    pub(crate) fn mul_impl(
-        &self,
-        rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
-    ) -> Result<Self> {
-        self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
-    }
-
-    pub(crate) fn div_impl(
-        &self,
-        rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
-    ) -> Result<Self> {
-        self.binary_impl::<Div>(rhs, shape, lhs_stride, rhs_stride)
-    }
-
-    pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
-        self.unary_impl::<Neg>(shape, stride)
-    }
-
-    pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
-        self.unary_impl::<Sqr>(shape, stride)
-    }
-
-    pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
-        self.unary_impl::<Sqrt>(shape, stride)
-    }
-
     pub(crate) fn matmul_impl(
         &self,
         rhs: &Self,
diff --git a/src/tensor.rs b/src/tensor.rs
index 66807594..07744d70 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -43,10 +43,12 @@ impl std::fmt::Debug for Tensor {
 }
 
 macro_rules! unary_op {
-    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+    ($fn_name:ident, $op_name:ident) => {
         pub fn $fn_name(&self) -> Result<Self> {
             let shape = self.shape();
-            let storage = self.storage.$impl_name(self.shape(), self.stride())?;
+            let storage = self
+                .storage
+                .unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage,
@@ -61,12 +63,15 @@ macro_rules! unary_op {
 }
 
 macro_rules! binary_op {
-    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+    ($fn_name:ident, $op_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
             let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage =
-                self.storage
-                    .$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
+            let storage = self.storage.binary_impl::<crate::op::$op_name>(
+                &rhs.storage,
+                shape,
+                self.stride(),
+                rhs.stride(),
+            )?;
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage,
@@ -211,14 +216,14 @@ impl Tensor {
     // TODO: Also make an inplace version or a pre-allocated? This could be tricky
     // if this can create cycles in the compute graph.
-    binary_op!(add, Add, add_impl);
-    binary_op!(mul, Mul, mul_impl);
-    binary_op!(sub, Sub, sub_impl);
-    binary_op!(div, Div, div_impl);
+    binary_op!(add, Add);
+    binary_op!(mul, Mul);
+    binary_op!(sub, Sub);
+    binary_op!(div, Div);
 
-    unary_op!(neg, Neg, neg_impl);
-    unary_op!(sqr, Sqr, sqr_impl);
-    unary_op!(sqrt, Sqrt, sqrt_impl);
+    unary_op!(neg, Neg);
+    unary_op!(sqr, Sqr);
+    unary_op!(sqrt, Sqrt);
 
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
         if self.rank() != 0 {
             return Err(Error::UnexpectedNumberOfDims {
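
Note on the new design: each element-wise op is now a zero-sized type implementing the `UnaryOp`/`BinaryOp` traits in src/op.rs, and every backend dispatches on it through `unary_impl::<Op>` / `binary_impl::<Op>`, so the per-op `*_impl` wrappers on `Storage` are no longer needed. A minimal standalone sketch of the pattern follows; the `Exp` op, its `uexp_f32` kernel name, and the `apply_unary` helper are hypothetical illustrations and not part of this diff.

// Standalone sketch of the op-trait pattern introduced above (not the crate's
// actual API). The trait carries both the scalar function used by the CPU
// backend and the kernel name a GPU backend would look up.
pub trait UnaryOp {
    const NAME: &'static str;
    const KERNEL_F32: &'static str;
    fn f32(v1: f32) -> f32;
}

// A hypothetical new op: a single trait impl is all a generic `unary_impl`
// style function needs to support it on every backend.
struct Exp;

impl UnaryOp for Exp {
    const NAME: &'static str = "exp";
    const KERNEL_F32: &'static str = "uexp_f32"; // hypothetical kernel name
    fn f32(v1: f32) -> f32 {
        v1.exp()
    }
}

// CPU-side application over a contiguous f32 buffer, mirroring what a
// contiguous fast path of `unary_impl::<U>` would do.
fn apply_unary<U: UnaryOp>(data: &[f32]) -> Vec<f32> {
    data.iter().copied().map(U::f32).collect()
}

fn main() {
    let out = apply_unary::<Exp>(&[0.0, 1.0, 2.0]);
    println!("{} -> {:?}", Exp::NAME, out);
}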