Add binary and ternary custom ops. (#217)

This commit is contained in:
Laurent Mazare
2023-07-21 18:29:50 +02:00
committed by GitHub
parent 4a100875bf
commit 5cc843550d
5 changed files with 209 additions and 10 deletions

View File

@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOp, CmpOp, CustomOp1, Op, ReduceOp, UnaryOp};
use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -1705,6 +1705,48 @@ impl Tensor {
pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
self.custom_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
let (storage, shape) = self.storage().custom_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = if self.track_op() {
Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
} else {
None
};
Ok(from_storage(storage, shape, op, false))
}
pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
self.custom_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
let (storage, shape) = self.storage().custom_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = if self.track_op() {
Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
} else {
None
};
Ok(from_storage(storage, shape, op, false))
}
pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}
macro_rules! bin_trait {