diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 6108b613..0f48dc62 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2088,6 +2088,22 @@ macro_rules! bin_trait { } } + impl> std::ops::$trait for Result { + type Output = Result; + + fn $fn1(self, rhs: Tensor) -> Self::Output { + Tensor::$fn1(self?.borrow(), &rhs) + } + } + + impl> std::ops::$trait<&Tensor> for Result { + type Output = Result; + + fn $fn1(self, rhs: &Tensor) -> Self::Output { + Tensor::$fn1(self?.borrow(), rhs) + } + } + impl> std::ops::$trait> for Tensor { type Output = Result;