diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 2b57f7f7..cf99f86e 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -118,13 +118,22 @@ pub enum Op {
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
     Elu(Tensor, f64),
-    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
-    CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
-    CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
+    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
+    CustomOp2(
+        Tensor,
+        Tensor,
+        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
+    ),
+    CustomOp3(
+        Tensor,
+        Tensor,
+        Tensor,
+        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
+    ),
 }
 
 /// Unary ops that can be defined in user-land.
-pub trait CustomOp1: Send + Sync {
+pub trait CustomOp1 {
     // Box<dyn> does not support const yet, so use a function to get the name.
     fn name(&self) -> &'static str;
 
@@ -148,7 +157,7 @@ pub trait CustomOp1: Send + Sync {
     }
 }
 
-pub trait CustomOp2: Send + Sync {
+pub trait CustomOp2 {
     fn name(&self) -> &'static str;
 
     /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -186,7 +195,7 @@
     }
 }
 
-pub trait CustomOp3: Send + Sync {
+pub trait CustomOp3 {
     fn name(&self) -> &'static str;
 
     /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index a0ed5b4d..a334b2c1 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -147,11 +147,11 @@ impl QTensor {
     }
 }
 
-pub struct QMatMul(std::sync::Arc<Box<dyn crate::CustomOp1>>);
+pub struct QMatMul(QTensor);
 
 impl QMatMul {
     pub fn from_qtensor(qtensor: QTensor) -> Self {
-        Self(std::sync::Arc::new(Box::new(qtensor)))
+        Self(qtensor)
     }
 }
 
@@ -196,6 +196,6 @@ impl crate::CustomOp1 for QTensor {
 
 impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.custom_op1_arc(self.0.clone())
+        xs.apply_op1_no_bwd(&self.0)
     }
 }
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 791b65dd..4a6cdc34 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -138,7 +138,7 @@ impl Storage {
         }
     }
 
-    pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
+    pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
         match self {
             Self::Cpu(storage) => {
                 let (storage, shape) = c.cpu_fwd(storage, l)?;
@@ -151,7 +151,7 @@
         }
     }
 
-    pub(crate) fn custom_op2(
+    pub(crate) fn apply_op2(
         &self,
         l1: &Layout,
         t2: &Self,
@@ -172,7 +172,7 @@
         }
     }
 
-    pub(crate) fn custom_op3(
+    pub(crate) fn apply_op3(
         &self,
         l1: &Layout,
         t2: &Self,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index c14a4e39..c71ea5ec 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1870,22 +1870,53 @@ impl Tensor {
         std::ptr::eq(lhs, rhs)
     }
 
+    /// Applies a unary custom op without backward support
+    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
+        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), false))
+    }
+
+    /// Applies a binary custom op without backward support
+    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
+        let (storage, shape) =
+            self.storage()
+                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), false))
+    }
+
+    /// Applies a ternary custom op without backward support
+    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
+        let (storage, shape) = self.storage().apply_op3(
+            self.layout(),
+            &t2.storage(),
+            t2.layout(),
+            &t3.storage(),
+            t3.layout(),
+            c,
+        )?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), false))
+    }
+
     /// Applies a unary custom op.
-    pub fn custom_op1_arc(&self, c: Arc<Box<dyn CustomOp1>>) -> Result<Self> {
+    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
         let (storage, shape) = self
             .storage()
-            .custom_op1(self.layout(), c.as_ref().as_ref())?;
+            .apply_op1(self.layout(), c.as_ref().as_ref())?;
         let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
         Ok(from_storage(storage, shape, op, false))
     }
 
-    pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
-        self.custom_op1_arc(Arc::new(Box::new(c)))
+    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
+        self.apply_op1_arc(Arc::new(Box::new(c)))
     }
 
     /// Applies a binary custom op.
-    pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
-        let (storage, shape) = self.storage().custom_op2(
+    pub fn apply_op2_arc(
+        &self,
+        rhs: &Self,
+        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
+    ) -> Result<Self> {
+        let (storage, shape) = self.storage().apply_op2(
             self.layout(),
             &rhs.storage(),
             rhs.layout(),
@@ -1895,13 +1926,18 @@
         Ok(from_storage(storage, shape, op, false))
     }
 
-    pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
-        self.custom_op2_arc(r, Arc::new(Box::new(c)))
+    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
+        self.apply_op2_arc(r, Arc::new(Box::new(c)))
     }
 
     /// Applies a ternary custom op.
-    pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
-        let (storage, shape) = self.storage().custom_op3(
+    pub fn apply_op3_arc(
+        &self,
+        t2: &Self,
+        t3: &Self,
+        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
+    ) -> Result<Self> {
+        let (storage, shape) = self.storage().apply_op3(
             self.layout(),
             &t2.storage(),
             t2.layout(),
@@ -1915,8 +1951,13 @@
         Ok(from_storage(storage, shape, op, false))
     }
 
-    pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
-        self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
+    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
+        &self,
+        t2: &Self,
+        t3: &Self,
+        c: C,
+    ) -> Result<Self> {
+        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
     }
 }
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs
index 55b5e894..7ec04c6a 100644
--- a/candle-core/tests/custom_op_tests.rs
+++ b/candle-core/tests/custom_op_tests.rs
@@ -39,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> {
     let cpu = &Device::Cpu;
     let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
     let t = (t - 5.)?;
-    let elu_t = t.custom_op1(Elu { alpha: 1. })?;
+    let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
     assert_eq!(
         to_vec1_round(&elu_t, 4)?,
         &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
@@ -96,7 +96,7 @@ impl CustomOp1 for EluWithBackward {
 
     fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
         let alpha = self.0.alpha;
-        let bwd = arg.custom_op1(EluBackward { alpha })?;
+        let bwd = arg.apply_op1(EluBackward { alpha })?;
         Ok(Some(grad_res.mul(&bwd)?))
     }
 }
@@ -105,7 +105,7 @@
 fn custom_op1_with_backward() -> Result<()> {
     let cpu = &Device::Cpu;
     let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
-    let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
+    let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
     assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
 
     let grads = elu_t.backward()?;
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
index 63bcd83a..7f7a3f26 100644
--- a/candle-examples/examples/custom-ops/main.rs
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> {
     let device = candle_examples::device(args.cpu)?;
     let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
     println!("{t}");
-    let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
+    let t = t.apply_op1(LayerNorm { eps: 1e-5 })?;
     println!("{t}");
     Ok(())
 }
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index ab4e382c..ad5e4cd2 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -68,7 +68,7 @@ impl CustomOp1 for AllReduce {
 }
 
 fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
-    x.custom_op1(AllReduce { comm: comm.clone() })
+    x.apply_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 092743f1..3c5fd455 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -178,7 +178,7 @@ pub fn flash_attn(
         softmax_scale,
         causal,
     };
-    q.custom_op3(k, v, op)
+    q.apply_op3(k, v, op)
 }
 
 struct FlashAttnVarLen {
@@ -402,5 +402,5 @@ pub fn flash_attn_varlen(
         seqlens_q: seqlens_q.clone(),
         seqlens_k: seqlens_k.clone(),
     };
-    q.custom_op3(k, v, op)
+    q.apply_op3(k, v, op)
 }
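For context, a minimal sketch (not part of the patch) of how user code looks against the renamed API. The `Scale` op, its `factor` field, and the `main` driver are hypothetical illustrations; only `CustomOp1`, `apply_op1`, and `apply_op1_no_bwd` come from the change above. Note that `apply_op1` now requires `C: 'static + CustomOp1 + Send + Sync` because the bound moved from the trait onto the stored trait object, while `apply_op1_no_bwd` takes a plain reference and records no backward op.

use candle_core::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

// Hypothetical op for illustration: multiply every element by a constant factor.
struct Scale {
    factor: f64,
}

impl CustomOp1 for Scale {
    fn name(&self) -> &'static str {
        "scale"
    }

    // CPU forward pass; only the contiguous f32 case is handled in this sketch.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let slice = storage.as_slice::<f32>()?;
        let src = match layout.contiguous_offsets() {
            Some((o1, o2)) => &slice[o1..o2],
            None => candle_core::bail!("input has to be contiguous"),
        };
        let out: Vec<f32> = src.iter().map(|v| v * self.factor as f32).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 4f32, &Device::Cpu)?;
    // No backward op is recorded for this call.
    let scaled = t.apply_op1_no_bwd(&Scale { factor: 2.0 })?;
    // This call records Op::CustomOp1; implementing bwd() would enable gradients.
    let tracked = t.apply_op1(Scale { factor: 2.0 })?;
    println!("{scaled}\n{tracked}");
    Ok(())
}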