Custom ops with a single argument (#214)

* Add the CustomOp1 trait.

* Add an example of custom op.

* Polish the custom op example.

* Add some backward pass test for custom ops.
This commit is contained in:
Laurent Mazare
2023-07-21 16:18:05 +02:00
committed by GitHub
parent b02229ce92
commit a6bcdfb269
8 changed files with 241 additions and 18 deletions

View File

@ -1,5 +1,5 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, ReduceOp};
use crate::op::{self, CmpOp, CustomOp1, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
@ -147,6 +147,19 @@ impl Storage {
}
}
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
match self {
Storage::Cpu(storage) => {
let (storage, shape) = c.cpu_fwd(storage, l)?;
Ok((Self::Cpu(storage), shape))
}
Self::Cuda(storage) => {
let (storage, shape) = c.cuda_fwd(storage, l)?;
Ok((Self::Cuda(storage), shape))
}
}
}
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {