Relax the requirements on CustomOp. (#486)

* Relax the requirements on CustomOp.

* Simplify the custom-ops when no backward is required.
This commit is contained in:
Laurent Mazare
2023-08-17 11:12:05 +01:00
committed by GitHub
parent d32e8199cd
commit 03be33eea4
8 changed files with 81 additions and 31 deletions

View File

@ -118,13 +118,22 @@ pub enum Op {
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Elu(Tensor, f64),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
CustomOp2(
Tensor,
Tensor,
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
),
CustomOp3(
Tensor,
Tensor,
Tensor,
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
),
}
/// Unary ops that can be defined in user-land.
pub trait CustomOp1: Send + Sync {
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
@ -148,7 +157,7 @@ pub trait CustomOp1: Send + Sync {
}
}
pub trait CustomOp2: Send + Sync {
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@ -186,7 +195,7 @@ pub trait CustomOp2: Send + Sync {
}
}
pub trait CustomOp3: Send + Sync {
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,