From 0ed24b9852ccc7dfb92d555afba3d56c2a3f3224 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 14 Nov 2024 21:08:04 +0100 Subject: [PATCH] Add max-all/min-all. (#2616) --- candle-core/src/tensor.rs | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e7355aad..75dc1c8a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1760,6 +1760,42 @@ impl Tensor { &self.op } + /// Computes the max of all the elements in this tensor and returns a tensor holding this + /// scalar with zero dimensions. + /// + /// ```rust + /// use candle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.max_all()?; + /// assert_eq!(tensor.to_scalar::()?, 5.); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn max_all(&self) -> Result { + if self.rank() == 0 { + Ok(self.clone()) + } else { + self.flatten_all()?.max(0) + } + } + + /// Computes the min of all the elements in this tensor and returns a tensor holding this + /// scalar with zero dimensions. + /// + /// ```rust + /// use candle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.min_all()?; + /// assert_eq!(tensor.to_scalar::()?, 0.); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn min_all(&self) -> Result { + if self.rank() == 0 { + Ok(self.clone()) + } else { + self.flatten_all()?.min(0) + } + } + /// Computes the sum of all the elements in this tensor and returns a tensor holding this /// scalar with zero dimensions. ///