Add binary and ternary custom ops. (#217)

This commit is contained in:
Laurent Mazare
2023-07-21 18:29:50 +02:00
committed by GitHub
parent 4a100875bf
commit 5cc843550d
5 changed files with 209 additions and 10 deletions

View File

@ -1,5 +1,5 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, ReduceOp};
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
@ -149,7 +149,7 @@ impl Storage {
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
match self {
Storage::Cpu(storage) => {
Self::Cpu(storage) => {
let (storage, shape) = c.cpu_fwd(storage, l)?;
Ok((Self::Cpu(storage), shape))
}
@ -160,6 +160,51 @@ impl Storage {
}
}
pub(crate) fn custom_op2(
&self,
l1: &Layout,
t2: &Self,
l2: &Layout,
c: &dyn CustomOp2,
) -> Result<(Self, Shape)> {
self.same_device(t2, c.name())?;
match (self, t2) {
(Self::Cpu(s1), Self::Cpu(s2)) => {
let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
Ok((Self::Cpu(s), shape))
}
(Self::Cuda(s1), Self::Cuda(s2)) => {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape))
}
_ => unreachable!(),
}
}
pub(crate) fn custom_op3(
&self,
l1: &Layout,
t2: &Self,
l2: &Layout,
t3: &Self,
l3: &Layout,
c: &dyn CustomOp3,
) -> Result<(Self, Shape)> {
self.same_device(t2, c.name())?;
self.same_device(t3, c.name())?;
match (self, t2, t3) {
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cpu(s), shape))
}
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape))
}
_ => unreachable!(),
}
}
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {