Custom ops with a single argument (#214)

* Add the CustomOp1 trait.

* Add an example of a custom op.

* Polish the custom op example.

* Add some backward pass tests for custom ops.
This commit is contained in:
Laurent Mazare
2023-07-21 16:18:05 +02:00
committed by GitHub
parent b02229ce92
commit a6bcdfb269
8 changed files with 241 additions and 18 deletions

View File

@ -1,4 +1,4 @@
use crate::Tensor;
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@ -93,10 +93,35 @@ pub(crate) enum Op {
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Elu(Tensor, f64),
// TODO: Support for custom ops.
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
}
pub(crate) trait UnaryOpT {
/// Unary ops that can be defined in user-land.
pub trait CustomOp1: Send + Sync {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _: &CudaStorage, _: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Tensor> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait UnaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@ -119,7 +144,7 @@ pub(crate) trait UnaryOpT {
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
pub(crate) trait BinaryOpT {
pub trait BinaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;