Remove one level of indirection for the binary and unary ops.

This commit is contained in:
laurent
2023-06-22 15:20:51 +01:00
parent 5276755fb3
commit 836ad5f76c
6 changed files with 142 additions and 189 deletions

View File

@ -1,4 +1,4 @@
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
#[derive(Debug, Clone)]
pub enum Storage {
@ -6,118 +6,6 @@ pub enum Storage {
Cuda(CudaStorage),
}
pub(crate) trait UnaryOp {
const NAME: &'static str;
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
// contiguous case separately as it's easy to optimize things out there.
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64;
}
pub(crate) trait BinaryOp {
const NAME: &'static str;
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
// contiguous case separately as it's easy to optimize things out there.
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
}
struct Add;
struct Div;
struct Mul;
struct Sub;
struct Neg;
struct Sqr;
struct Sqrt;
impl BinaryOp for Add {
const NAME: &'static str = "add";
const KERNEL_F32: &'static str = "badd_f32";
const KERNEL_F64: &'static str = "badd_f64";
fn f32(v1: f32, v2: f32) -> f32 {
v1 + v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 + v2
}
}
impl BinaryOp for Sub {
const NAME: &'static str = "sub";
const KERNEL_F32: &'static str = "bsub_f32";
const KERNEL_F64: &'static str = "bsub_f64";
fn f32(v1: f32, v2: f32) -> f32 {
v1 - v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 - v2
}
}
impl BinaryOp for Mul {
const NAME: &'static str = "mul";
const KERNEL_F32: &'static str = "bmul_f32";
const KERNEL_F64: &'static str = "bmul_f64";
fn f32(v1: f32, v2: f32) -> f32 {
v1 * v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 * v2
}
}
impl BinaryOp for Div {
const NAME: &'static str = "div";
const KERNEL_F32: &'static str = "bdiv_f32";
const KERNEL_F64: &'static str = "bdiv_f64";
fn f32(v1: f32, v2: f32) -> f32 {
v1 / v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 / v2
}
}
impl UnaryOp for Neg {
const NAME: &'static str = "neg";
fn f32(v1: f32) -> f32 {
-v1
}
fn f64(v1: f64) -> f64 {
-v1
}
const KERNEL_F32: &'static str = "uneg_f32";
const KERNEL_F64: &'static str = "uneg_f64";
}
impl UnaryOp for Sqr {
const NAME: &'static str = "sqr";
fn f32(v1: f32) -> f32 {
v1 * v1
}
fn f64(v1: f64) -> f64 {
v1 * v1
}
const KERNEL_F32: &'static str = "usqr_f32";
const KERNEL_F64: &'static str = "usqr_f64";
}
impl UnaryOp for Sqrt {
const NAME: &'static str = "sqrt";
fn f32(v1: f32) -> f32 {
v1.sqrt()
}
fn f64(v1: f64) -> f64 {
v1.sqrt()
}
const KERNEL_F32: &'static str = "usqrt_f32";
const KERNEL_F64: &'static str = "usqrt_f64";
}
impl Storage {
pub fn device(&self) -> Device {
match self {
@ -173,7 +61,11 @@ impl Storage {
}
}
fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
pub(crate) fn unary_impl<B: op::UnaryOp>(
&self,
shape: &Shape,
stride: &[usize],
) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
@ -188,7 +80,7 @@ impl Storage {
}
// TODO: Support broadcasting?
fn binary_impl<B: BinaryOp>(
pub(crate) fn binary_impl<B: op::BinaryOp>(
&self,
rhs: &Self,
shape: &Shape,
@ -218,58 +110,6 @@ impl Storage {
}
}
pub(crate) fn add_impl(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn sub_impl(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.binary_impl::<Sub>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn mul_impl(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn div_impl(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.binary_impl::<Div>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Neg>(shape, stride)
}
pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqr>(shape, stride)
}
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqrt>(shape, stride)
}
pub(crate) fn matmul_impl(
&self,
rhs: &Self,