Add more binary kernels.

This commit is contained in:
laurent
2023-06-22 14:07:02 +01:00
parent 97fe1fac85
commit b8f514d9c6
9 changed files with 54 additions and 87 deletions

View File

@ -14,15 +14,10 @@ pub(crate) trait UnaryOp {
/// A two-operand operation described by a name, per-dtype kernel names, scalar
/// implementations for `f32`/`f64`, and a CUDA entry point.
//
// NOTE(review): this text looks like a rendered diff with its +/- markers
// stripped, so the `KERNEL_*` constants and `cuda_impl` may belong to
// different revisions rather than coexisting — confirm against the real file.
pub(crate) trait BinaryOp {
/// Short identifier of the operation (the impls below use e.g. "add", "sub").
const NAME: &'static str;
/// Name of the kernel handling `f32` operands (e.g. "badd_f32").
/// Presumably a CUDA kernel symbol — TODO confirm.
const KERNEL_F32: &'static str;
/// Name of the kernel handling `f64` operands (e.g. "badd_f64").
/// Presumably a CUDA kernel symbol — TODO confirm.
const KERNEL_F64: &'static str;
/// Applies the operation to a pair of `f32` scalars.
fn f32(v1: f32, v2: f32) -> f32;
/// Applies the operation to a pair of `f64` scalars.
fn f64(v1: f64, v2: f64) -> f64;
/// Applies the operation to two `CudaStorage` operands.
///
/// Takes the shape and one stride slice per operand; returns the resulting
/// `CudaStorage` or an error.
fn cuda_impl(
lhs: &CudaStorage,
rhs: &CudaStorage,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<CudaStorage>;
}
struct Add;
@ -35,78 +30,50 @@ struct Sqrt;
/// Addition: `v1 + v2`.
impl BinaryOp for Add {
const NAME: &'static str = "add";
// Kernel names for the two supported float widths.
const KERNEL_F32: &'static str = "badd_f32";
const KERNEL_F64: &'static str = "badd_f64";
/// Scalar `f32` addition.
fn f32(v1: f32, v2: f32) -> f32 {
v1 + v2
}
/// Scalar `f64` addition.
fn f64(v1: f64, v2: f64) -> f64 {
v1 + v2
}
/// CUDA addition: forwards every argument to `lhs.add_impl(..)`.
///
/// Any error from `add_impl` is propagated with `?`; the success value is
/// re-wrapped in `Ok`.
fn cuda_impl(
lhs: &CudaStorage,
rhs: &CudaStorage,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<CudaStorage> {
Ok(lhs.add_impl(rhs, shape, lhs_stride, rhs_stride)?)
}
}
/// Subtraction: `v1 - v2`.
impl BinaryOp for Sub {
const NAME: &'static str = "sub";
// Kernel names for the two supported float widths.
const KERNEL_F32: &'static str = "bsub_f32";
const KERNEL_F64: &'static str = "bsub_f64";
/// Scalar `f32` subtraction (left minus right).
fn f32(v1: f32, v2: f32) -> f32 {
v1 - v2
}
/// Scalar `f64` subtraction (left minus right).
fn f64(v1: f64, v2: f64) -> f64 {
v1 - v2
}
/// CUDA subtraction — not implemented here.
///
/// # Panics
/// Always panics: the body is a `todo!()` stub and all arguments are ignored.
fn cuda_impl(
_: &CudaStorage,
_: &CudaStorage,
_: &Shape,
_: &[usize],
_: &[usize],
) -> Result<CudaStorage> {
todo!()
}
}
/// Multiplication: `v1 * v2`.
impl BinaryOp for Mul {
const NAME: &'static str = "mul";
// Kernel names for the two supported float widths.
const KERNEL_F32: &'static str = "bmul_f32";
const KERNEL_F64: &'static str = "bmul_f64";
/// Scalar `f32` multiplication.
fn f32(v1: f32, v2: f32) -> f32 {
v1 * v2
}
/// Scalar `f64` multiplication.
fn f64(v1: f64, v2: f64) -> f64 {
v1 * v2
}
/// CUDA multiplication — not implemented here.
///
/// # Panics
/// Always panics: the body is a `todo!()` stub and all arguments are ignored.
fn cuda_impl(
_: &CudaStorage,
_: &CudaStorage,
_: &Shape,
_: &[usize],
_: &[usize],
) -> Result<CudaStorage> {
todo!()
}
}
/// Division: `v1 / v2`.
impl BinaryOp for Div {
const NAME: &'static str = "div";
// Kernel names for the two supported float widths.
const KERNEL_F32: &'static str = "bdiv_f32";
const KERNEL_F64: &'static str = "bdiv_f64";
/// Scalar `f32` division (left divided by right).
fn f32(v1: f32, v2: f32) -> f32 {
v1 / v2
}
/// Scalar `f64` division (left divided by right).
fn f64(v1: f64, v2: f64) -> f64 {
v1 / v2
}
/// CUDA division — not implemented here.
///
/// # Panics
/// Always panics: the body is a `todo!()` stub and all arguments are ignored.
fn cuda_impl(
_: &CudaStorage,
_: &CudaStorage,
_: &Shape,
_: &[usize],
_: &[usize],
) -> Result<CudaStorage> {
todo!()
}
}
impl UnaryOp for Neg {
@ -221,7 +188,7 @@ impl Storage {
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = B::cuda_impl(lhs, rhs, shape, lhs_stride, rhs_stride)?;
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => {