Add max-all/min-all. (#2616)

This commit is contained in:
Laurent Mazare
2024-11-14 21:08:04 +01:00
committed by GitHub
parent 06350c31c7
commit 0ed24b9852

View File

@ -1760,6 +1760,42 @@ impl Tensor {
&self.op
}
/// Computes the max of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.max_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn max_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.max(0)
}
}
/// Computes the min of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.min_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn min_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.min(0)
}
}
/// Computes the sum of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///