Add binary and ternary custom ops. (#217)

This commit is contained in:
Laurent Mazare
2023-07-21 18:29:50 +02:00
committed by GitHub
parent 4a100875bf
commit 5cc843550d
5 changed files with 209 additions and 10 deletions

View File

@ -38,7 +38,7 @@ impl Tensor {
nodes
} else if let Some(op) = node.op() {
match op {
Op::WhereCond(t1, t2, t3) => {
Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t2, nodes, already_seen);
@ -52,6 +52,7 @@ impl Tensor {
kernel: rhs,
..
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
@ -321,10 +322,38 @@ impl Tensor {
Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
let sum_grad = grads.or_insert(arg)?;
let arg_grad = c.bwd(arg, node, &grad)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
}
Op::CustomOp2(arg1, arg2, c) => {
let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
if let Some(arg_grad1) = arg_grad1 {
let sum_grad = grads.or_insert(arg1)?;
*sum_grad = sum_grad.add(&arg_grad1)?
}
if let Some(arg_grad2) = arg_grad2 {
let sum_grad = grads.or_insert(arg2)?;
*sum_grad = sum_grad.add(&arg_grad2)?
}
}
Op::CustomOp3(arg1, arg2, arg3, c) => {
let (arg_grad1, arg_grad2, arg_grad3) =
c.bwd(arg1, arg2, arg3, node, &grad)?;
if let Some(arg_grad1) = arg_grad1 {
let sum_grad = grads.or_insert(arg1)?;
*sum_grad = sum_grad.add(&arg_grad1)?
}
if let Some(arg_grad2) = arg_grad2 {
let sum_grad = grads.or_insert(arg2)?;
*sum_grad = sum_grad.add(&arg_grad2)?
}
if let Some(arg_grad3) = arg_grad3 {
let sum_grad = grads.or_insert(arg3)?;
*sum_grad = sum_grad.add(&arg_grad3)?
}
}
Op::Unary(arg, UnaryOp::Sqr) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?;

View File

@ -94,6 +94,8 @@ pub(crate) enum Op {
Transpose(Tensor, usize, usize),
Elu(Tensor, f64),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
}
/// Unary ops that can be defined in user-land.
@ -116,7 +118,88 @@ pub trait CustomOp1: Send + Sync {
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Tensor> {
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}

View File

@ -1,5 +1,5 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, ReduceOp};
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
@ -149,7 +149,7 @@ impl Storage {
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
match self {
Storage::Cpu(storage) => {
Self::Cpu(storage) => {
let (storage, shape) = c.cpu_fwd(storage, l)?;
Ok((Self::Cpu(storage), shape))
}
@ -160,6 +160,51 @@ impl Storage {
}
}
pub(crate) fn custom_op2(
&self,
l1: &Layout,
t2: &Self,
l2: &Layout,
c: &dyn CustomOp2,
) -> Result<(Self, Shape)> {
self.same_device(t2, c.name())?;
match (self, t2) {
(Self::Cpu(s1), Self::Cpu(s2)) => {
let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
Ok((Self::Cpu(s), shape))
}
(Self::Cuda(s1), Self::Cuda(s2)) => {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape))
}
_ => unreachable!(),
}
}
pub(crate) fn custom_op3(
&self,
l1: &Layout,
t2: &Self,
l2: &Layout,
t3: &Self,
l3: &Layout,
c: &dyn CustomOp3,
) -> Result<(Self, Shape)> {
self.same_device(t2, c.name())?;
self.same_device(t3, c.name())?;
match (self, t2, t3) {
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cpu(s), shape))
}
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape))
}
_ => unreachable!(),
}
}
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {

View File

@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOp, CmpOp, CustomOp1, Op, ReduceOp, UnaryOp};
use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -1705,6 +1705,48 @@ impl Tensor {
pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
self.custom_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
let (storage, shape) = self.storage().custom_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = if self.track_op() {
Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
} else {
None
};
Ok(from_storage(storage, shape, op, false))
}
pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
self.custom_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
let (storage, shape) = self.storage().custom_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = if self.track_op() {
Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
} else {
None
};
Ok(from_storage(storage, shape, op, false))
}
pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}
macro_rules! bin_trait {

View File

@ -94,10 +94,10 @@ impl CustomOp1 for EluWithBackward {
self.0.cpu_fwd(s, l)
}
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Tensor> {
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
let alpha = self.0.alpha;
let bwd = arg.custom_op1(EluBackward { alpha })?;
grad_res.mul(&bwd)
Ok(Some(grad_res.mul(&bwd)?))
}
}