diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 9ae6c23c..bfbb350e 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -38,7 +38,7 @@ impl Tensor {
                 nodes
             } else if let Some(op) = node.op() {
                 match op {
-                    Op::WhereCond(t1, t2, t3) => {
+                    Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
                         let (tg, nodes) = walk(t1, nodes, already_seen);
                         track_grad |= tg;
                         let (tg, nodes) = walk(t2, nodes, already_seen);
@@ -52,6 +52,7 @@ impl Tensor {
                         kernel: rhs,
                         ..
                     }
+                    | Op::CustomOp2(lhs, rhs, _)
                    | Op::Binary(lhs, rhs, _)
                    | Op::IndexSelect(lhs, rhs, _)
                    | Op::Embedding(lhs, rhs)
@@ -321,9 +322,37 @@ impl Tensor {
                 Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
                 Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                 Op::CustomOp1(arg, c) => {
-                    let sum_grad = grads.or_insert(arg)?;
-                    let arg_grad = c.bwd(arg, node, &grad)?;
-                    *sum_grad = sum_grad.add(&arg_grad)?
+                    if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
+                }
+                Op::CustomOp2(arg1, arg2, c) => {
+                    let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
+                    if let Some(arg_grad1) = arg_grad1 {
+                        let sum_grad = grads.or_insert(arg1)?;
+                        *sum_grad = sum_grad.add(&arg_grad1)?
+                    }
+                    if let Some(arg_grad2) = arg_grad2 {
+                        let sum_grad = grads.or_insert(arg2)?;
+                        *sum_grad = sum_grad.add(&arg_grad2)?
+                    }
+                }
+                Op::CustomOp3(arg1, arg2, arg3, c) => {
+                    let (arg_grad1, arg_grad2, arg_grad3) =
+                        c.bwd(arg1, arg2, arg3, node, &grad)?;
+                    if let Some(arg_grad1) = arg_grad1 {
+                        let sum_grad = grads.or_insert(arg1)?;
+                        *sum_grad = sum_grad.add(&arg_grad1)?
+                    }
+                    if let Some(arg_grad2) = arg_grad2 {
+                        let sum_grad = grads.or_insert(arg2)?;
+                        *sum_grad = sum_grad.add(&arg_grad2)?
+                    }
+                    if let Some(arg_grad3) = arg_grad3 {
+                        let sum_grad = grads.or_insert(arg3)?;
+                        *sum_grad = sum_grad.add(&arg_grad3)?
+                    }
                 }
                 Op::Unary(arg, UnaryOp::Sqr) => {
                     let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 84fd12b1..a33dd226 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -94,6 +94,8 @@ pub(crate) enum Op {
     Transpose(Tensor, usize, usize),
     Elu(Tensor, f64),
     CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
+    CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
+    CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
 }
 
 /// Unary ops that can be defined in user-land.
@@ -116,7 +118,88 @@ pub trait CustomOp1: Send + Sync {
     /// This function takes as argument the argument `arg` used in the forward pass, the result
     /// produced by the forward operation `res` and the gradient of the result `grad_res`.
     /// The function should return the gradient of the argument.
-    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Tensor> {
+    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
         Err(crate::Error::BackwardNotSupported { op: self.name() })
     }
 }
+
+pub trait CustomOp2: Send + Sync {
+    fn name(&self) -> &'static str;
+
+    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<(CpuStorage, Shape)>;
+
+    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cuda_fwd(
+        &self,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+    ) -> Result<(CudaStorage, Shape)> {
+        Err(crate::Error::Cuda(
+            format!("no cuda implementation for {}", self.name()).into(),
+        ))
+    }
+
+    fn bwd(
+        &self,
+        _arg1: &Tensor,
+        _arg2: &Tensor,
+        _res: &Tensor,
+        _grad_res: &Tensor,
+    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        Err(crate::Error::BackwardNotSupported { op: self.name() })
+    }
+}
+
+pub trait CustomOp3: Send + Sync {
+    fn name(&self) -> &'static str;
+
+    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+        s3: &CpuStorage,
+        l3: &Layout,
+    ) -> Result<(CpuStorage, Shape)>;
+
+    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cuda_fwd(
+        &self,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+    ) -> Result<(CudaStorage, Shape)> {
+        Err(crate::Error::Cuda(
+            format!("no cuda implementation for {}", self.name()).into(),
+        ))
+    }
+
+    fn bwd(
+        &self,
+        _arg1: &Tensor,
+        _arg2: &Tensor,
+        _arg3: &Tensor,
+        _res: &Tensor,
+        _grad_res: &Tensor,
+    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
+        Err(crate::Error::BackwardNotSupported { op: self.name() })
+    }
+}
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 752af24b..2df21862 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,5 +1,5 @@
 use crate::backend::BackendStorage;
-use crate::op::{self, CmpOp, CustomOp1, ReduceOp};
+use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
 use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
 
 // We do not want to implement Clone on Storage as cloning may fail because of
@@ -149,7 +149,7 @@ impl Storage {
 
     pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
         match self {
-            Storage::Cpu(storage) => {
+            Self::Cpu(storage) => {
                 let (storage, shape) = c.cpu_fwd(storage, l)?;
                 Ok((Self::Cpu(storage), shape))
             }
@@ -160,6 +160,51 @@ impl Storage {
         }
     }
 
+    pub(crate) fn custom_op2(
+        &self,
+        l1: &Layout,
+        t2: &Self,
+        l2: &Layout,
+        c: &dyn CustomOp2,
+    ) -> Result<(Self, Shape)> {
+        self.same_device(t2, c.name())?;
+        match (self, t2) {
+            (Self::Cpu(s1), Self::Cpu(s2)) => {
+                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
+                Ok((Self::Cpu(s), shape))
+            }
+            (Self::Cuda(s1), Self::Cuda(s2)) => {
+                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
+                Ok((Self::Cuda(s), shape))
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    pub(crate) fn custom_op3(
+        &self,
+        l1: &Layout,
+        t2: &Self,
+        l2: &Layout,
+        t3: &Self,
+        l3: &Layout,
+        c: &dyn CustomOp3,
+    ) -> Result<(Self, Shape)> {
+        self.same_device(t2, c.name())?;
+        self.same_device(t3, c.name())?;
+        match (self, t2, t3) {
+            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
+                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
+                Ok((Self::Cpu(s), shape))
+            }
+            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
+                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
+                Ok((Self::Cuda(s), shape))
+            }
+            _ => unreachable!(),
+        }
+    }
+
     pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 84329a2f..e4e4ba6b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, CmpOp, CustomOp1, Op, ReduceOp, UnaryOp};
+use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
 use crate::shape::{Dim, Dims};
 use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
@@ -1705,6 +1705,48 @@ impl Tensor {
     pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
         self.custom_op1_arc(Arc::new(Box::new(c)))
     }
+
+    /// Applies a binary custom op.
+    pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
+        let (storage, shape) = self.storage().custom_op2(
+            self.layout(),
+            &rhs.storage(),
+            rhs.layout(),
+            c.as_ref().as_ref(),
+        )?;
+        let op = if self.track_op() {
+            Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape, op, false))
+    }
+
+    pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
+        self.custom_op2_arc(r, Arc::new(Box::new(c)))
+    }
+
+    /// Applies a ternary custom op.
+    pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
+        let (storage, shape) = self.storage().custom_op3(
+            self.layout(),
+            &t2.storage(),
+            t2.layout(),
+            &t3.storage(),
+            t3.layout(),
+            c.as_ref().as_ref(),
+        )?;
+        let op = if self.track_op() {
+            Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape, op, false))
+    }
+
+    pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
+        self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
+    }
 }
 
 macro_rules! bin_trait {
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs
index 3ce125bc..5c95dc82 100644
--- a/candle-core/tests/custom_op_tests.rs
+++ b/candle-core/tests/custom_op_tests.rs
@@ -94,10 +94,10 @@ impl CustomOp1 for EluWithBackward {
         self.0.cpu_fwd(s, l)
     }
 
-    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Tensor> {
+    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
         let alpha = self.0.alpha;
         let bwd = arg.custom_op1(EluBackward { alpha })?;
-        grad_res.mul(&bwd)
+        Ok(Some(grad_res.mul(&bwd)?))
     }
 }
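
As a usage reference, here is a minimal end-to-end sketch of the new binary-op API (not part of this diff; `MulOp` and its error messages are made up for illustration). It implements `CustomOp2` as an elementwise product over contiguous f32 CPU storage, and reports gradients for both arguments through the new `Option`-returning `bwd`. It assumes `CustomOp2` and `Var` are re-exported from the crate root the way `CustomOp1` already is; dtype dispatch, strided layouts, and the CUDA path are omitted for brevity.

```rust
use candle_core::{CpuStorage, CustomOp2, Device, Layout, Result, Shape, Tensor, Var};

// Hypothetical op: elementwise product of two tensors.
struct MulOp;

impl CustomOp2 for MulOp {
    fn name(&self) -> &'static str {
        "mul-op"
    }

    // Sketch only: f32 and contiguous inputs. A real kernel would dispatch on
    // dtype and walk the layouts to honor arbitrary strides and offsets.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        let (o1, o2) = match (l1.contiguous_offsets(), l2.contiguous_offsets()) {
            (Some(o1), Some(o2)) => (o1, o2),
            _ => candle_core::bail!("mul-op requires contiguous inputs"),
        };
        match (s1, s2) {
            (CpuStorage::F32(v1), CpuStorage::F32(v2)) => {
                let (v1, v2) = (&v1[o1.0..o1.1], &v2[o2.0..o2.1]);
                let out: Vec<f32> = v1.iter().zip(v2).map(|(a, b)| a * b).collect();
                Ok((CpuStorage::F32(out), l1.shape().clone()))
            }
            _ => candle_core::bail!("mul-op only supports f32"),
        }
    }

    // d(a*b)/da = grad * b and d(a*b)/db = grad * a. Returning None for an
    // argument makes backprop skip its gradient accumulation.
    fn bwd(
        &self,
        arg1: &Tensor,
        arg2: &Tensor,
        _res: &Tensor,
        grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Ok((Some(grad_res.mul(arg2)?), Some(grad_res.mul(arg1)?)))
    }
}

fn main() -> Result<()> {
    let a = Var::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let b = Var::new(&[4f32, 5., 6.], &Device::Cpu)?;
    let c = a.custom_op2(&b, MulOp)?; // [4., 10., 18.]
    let grads = c.backward()?;
    println!("dc/da = {}", grads.get(&a).unwrap()); // equals b
    println!("dc/db = {}", grads.get(&b).unwrap()); // equals a
    Ok(())
}
```

Note the split this diff establishes: the forward pass works on raw storage plus layout, while the backward pass composes ordinary tensor ops, mirroring the `EluWithBackward` pattern in `custom_op_tests.rs`.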