diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index a4995a60..1f5f45ab 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -1,7 +1,7 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; -pub(crate) trait BackendStorage: Sized { +pub trait BackendStorage: Sized { type Device: BackendDevice; fn try_clone(&self, _: &Layout) -> Result; @@ -53,7 +53,7 @@ pub(crate) trait BackendStorage: Sized { fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; } -pub(crate) trait BackendDevice: Sized + std::fmt::Debug + Clone { +pub trait BackendDevice: Sized + std::fmt::Debug + Clone { type Storage: BackendStorage; // TODO: Make the usize generic and part of a generic DeviceLocation. diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 62cbc488..9ae6c23c 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -86,7 +86,8 @@ impl Tensor { | Op::Narrow(node, _, _, _) | Op::Softmax(node, _) | Op::Unary(node, _) - | Op::Elu(node, _) => { + | Op::Elu(node, _) + | Op::CustomOp1(node, _) => { let (tg, nodes) = walk(node, nodes, already_seen); track_grad |= tg; nodes @@ -319,6 +320,11 @@ impl Tensor { Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?, Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?, Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?, + Op::CustomOp1(arg, c) => { + let sum_grad = grads.or_insert(arg)?; + let arg_grad = c.bwd(arg, node, &grad)?; + *sum_grad = sum_grad.add(&arg_grad)? + } Op::Unary(arg, UnaryOp::Sqr) => { let arg_grad = arg.mul(&grad)?.affine(2., 0.)?; let sum_grad = grads.or_insert(arg)?; diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 7901a7da..d529b173 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -18,7 +18,7 @@ pub enum CpuStorage { #[derive(Debug, Clone)] pub struct CpuDevice; -trait Map1 { +pub trait Map1 { fn f(&self, vs: &[T], layout: &Layout) -> Result>; fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result { @@ -33,7 +33,7 @@ trait Map1 { } } -trait Map1Any { +pub trait Map1Any { fn f) -> CpuStorage>( &self, vs: &[T], @@ -54,7 +54,7 @@ trait Map1Any { } type C = CpuStorage; -trait Map2 { +pub trait Map2 { const OP: &'static str; fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; @@ -82,7 +82,7 @@ trait Map2 { } } -trait Map2U8 { +pub trait Map2U8 { const OP: &'static str; fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; @@ -348,7 +348,11 @@ impl<'a> Map1 for Reduce<'a> { } } -fn unary_map U>(vs: &[T], layout: &Layout, mut f: F) -> Vec { +pub fn unary_map U>( + vs: &[T], + layout: &Layout, + mut f: F, +) -> Vec { match layout.strided_blocks() { crate::StridedBlocks::SingleBlock { start_offset, len } => vs [start_offset..start_offset + len] @@ -380,7 +384,7 @@ fn unary_map U>(vs: &[T], layout: &Layout, mut } } -fn unary_map_vec U, FV: FnMut(&[T], &mut [U])>( +pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U])>( vs: &[T], layout: &Layout, mut f: F, diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 7760e2c7..5a35955f 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -33,13 +33,13 @@ //! //! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers) -mod backend; +pub mod backend; mod backprop; mod conv; mod convert; -mod cpu_backend; +pub mod cpu_backend; #[cfg(feature = "cuda")] -mod cuda_backend; +pub mod cuda_backend; mod device; pub mod display; mod dtype; @@ -65,6 +65,7 @@ pub use dtype::{DType, WithDType}; pub use error::{Error, Result}; pub use indexer::IndexOp; pub use layout::Layout; +pub use op::CustomOp1; pub use shape::{Shape, D}; pub use storage::Storage; pub use strided_index::{StridedBlocks, StridedIndex}; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 226cff41..84fd12b1 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,4 +1,4 @@ -use crate::Tensor; +use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor}; use half::{bf16, f16}; use num_traits::float::Float; @@ -93,10 +93,35 @@ pub(crate) enum Op { ToDevice(Tensor), Transpose(Tensor, usize, usize), Elu(Tensor, f64), - // TODO: Support for custom ops. + CustomOp1(Tensor, std::sync::Arc>), } -pub(crate) trait UnaryOpT { +/// Unary ops that can be defined in user-land. +pub trait CustomOp1: Send + Sync { + // Box does not support const yet, so use a function to get the name. + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd(&self, _: &CudaStorage, _: &Layout) -> Result<(CudaStorage, Shape)> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// This function takes as argument the argument `arg` used in the forward pass, the result + /// produced by the forward operation `res` and the gradient of the result `grad_res`. + /// The function should return the gradient of the argument. + fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result { + Err(crate::Error::BackwardNotSupported { op: self.name() }) + } +} + +pub trait UnaryOpT { const NAME: &'static str; const KERNEL: &'static str; const V: Self; @@ -119,7 +144,7 @@ pub(crate) trait UnaryOpT { fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {} } -pub(crate) trait BinaryOpT { +pub trait BinaryOpT { const NAME: &'static str; const KERNEL: &'static str; const V: Self; diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 71edf3dd..752af24b 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -1,5 +1,5 @@ use crate::backend::BackendStorage; -use crate::op::{self, CmpOp, ReduceOp}; +use crate::op::{self, CmpOp, CustomOp1, ReduceOp}; use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape}; // We do not want to implement Clone on Storage as cloning may fail because of @@ -147,6 +147,19 @@ impl Storage { } } + pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> { + match self { + Storage::Cpu(storage) => { + let (storage, shape) = c.cpu_fwd(storage, l)?; + Ok((Self::Cpu(storage), shape)) + } + Self::Cuda(storage) => { + let (storage, shape) = c.cuda_fwd(storage, l)?; + Ok((Self::Cuda(storage), shape)) + } + } + } + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { // TODO: Different code path for the contiguous case? match self { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 05791ed1..84329a2f 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,5 +1,5 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::op::{BinaryOp, CmpOp, Op, ReduceOp, UnaryOp}; +use crate::op::{BinaryOp, CmpOp, CustomOp1, Op, ReduceOp, UnaryOp}; use crate::shape::{Dim, Dims}; use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock}; @@ -1688,6 +1688,23 @@ impl Tensor { let rhs: &RwLock = rhs.storage.as_ref(); std::ptr::eq(lhs, rhs) } + + /// Applies a unary custom op. + pub fn custom_op1_arc(&self, c: Arc>) -> Result { + let (storage, shape) = self + .storage() + .custom_op1(self.layout(), c.as_ref().as_ref())?; + let op = if self.track_op() { + Some(Op::CustomOp1(self.clone(), c)) + } else { + None + }; + Ok(from_storage(storage, shape, op, false)) + } + + pub fn custom_op1(&self, c: C) -> Result { + self.custom_op1_arc(Arc::new(Box::new(c))) + } } macro_rules! bin_trait { diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs new file mode 100644 index 00000000..3e1e0c19 --- /dev/null +++ b/candle-core/tests/custom_op_tests.rs @@ -0,0 +1,157 @@ +use candle::backend::BackendStorage; +use candle::cpu_backend; +use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor}; +use half::{bf16, f16}; + +mod test_utils; +use test_utils::to_vec1_round; + +fn fwd(v: T, alpha: T) -> T { + if v.is_sign_positive() { + v + } else { + (v.exp() - T::one()) * alpha + } +} + +struct Elu { + alpha: f64, +} + +impl CustomOp1 for Elu { + fn name(&self) -> &'static str { + "elu" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + use CpuStorage::*; + + // In this example, we pattern match over the different dtypes. Some helper functions and + // traits from the `cpu_backend` module can be used to avoid this in some common cases, see + // e.g. `Map1`. + let storage = match s { + BF16(s) => { + let alpha = bf16::from_f64(self.alpha); + let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); + BF16(data) + } + F16(s) => { + let alpha = f16::from_f64(self.alpha); + let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); + F16(data) + } + F32(s) => { + let alpha = self.alpha as f32; + let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha)); + F32(data) + } + F64(s) => { + let data = cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)); + F64(data) + } + _ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?, + }; + Ok((storage, l.shape().clone())) + } +} + +#[test] +fn custom_op1_no_backward() -> Result<()> { + let cpu = &Device::Cpu; + let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?; + let t = (t - 5.)?; + let elu_t = t.custom_op1(Elu { alpha: 1. })?; + assert_eq!( + to_vec1_round(&elu_t, 4)?, + &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + ); + Ok(()) +} + +// Define a similar struct as Elu but with backward support. +fn bwd(v: T, alpha: T) -> T { + if v.is_sign_positive() { + T::one() + } else { + v.exp() * alpha + } +} + +struct EluBackward { + alpha: f64, +} + +impl CustomOp1 for EluBackward { + fn name(&self) -> &'static str { + "elu-bwd" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + use CpuStorage::*; + + // In this example, we pattern match over the different dtypes. Some helper functions and + // traits from the `cpu_backend` module can be used to avoid this in some common cases, see + // e.g. `Map1`. + let storage = match s { + BF16(s) => { + let alpha = bf16::from_f64(self.alpha); + let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); + BF16(data) + } + F16(s) => { + let alpha = f16::from_f64(self.alpha); + let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); + F16(data) + } + F32(s) => { + let alpha = self.alpha as f32; + let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha)); + F32(data) + } + F64(s) => { + let data = cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)); + F64(data) + } + _ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?, + }; + Ok((storage, l.shape().clone())) + } +} + +struct EluWithBackward(Elu); + +impl EluWithBackward { + fn new(alpha: f64) -> Self { + Self(Elu { alpha }) + } +} + +impl CustomOp1 for EluWithBackward { + fn name(&self) -> &'static str { + "elu" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + self.0.cpu_fwd(s, l) + } + + fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result { + let alpha = self.0.alpha; + let bwd = arg.custom_op1(EluBackward { alpha })?; + grad_res.mul(&bwd) + } +} + +#[test] +fn custom_op1_with_backward() -> Result<()> { + let cpu = &Device::Cpu; + let t = candle::Var::new(&[-2f32, 0f32, 2f32], cpu)?; + let elu_t = t.custom_op1(EluWithBackward::new(2.))?; + assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]); + + let grads = elu_t.backward()?; + let grad_x = grads.get(&t).unwrap(); + assert_eq!(to_vec1_round(grad_x, 4)?, [0.2707, 1.0, 1.0]); + + Ok(()) +}