mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add max-all/min-all. (#2616)
This commit is contained in:
@ -1760,6 +1760,42 @@ impl Tensor {
|
||||
&self.op
|
||||
}
|
||||
|
||||
/// Computes the max of all the elements in this tensor and returns a tensor holding this
|
||||
/// scalar with zero dimensions.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.max_all()?;
|
||||
/// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn max_all(&self) -> Result<Tensor> {
|
||||
if self.rank() == 0 {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
self.flatten_all()?.max(0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the min of all the elements in this tensor and returns a tensor holding this
|
||||
/// scalar with zero dimensions.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.min_all()?;
|
||||
/// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn min_all(&self) -> Result<Tensor> {
|
||||
if self.rank() == 0 {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
self.flatten_all()?.min(0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the sum of all the elements in this tensor and returns a tensor holding this
|
||||
/// scalar with zero dimensions.
|
||||
///
|
||||
|
Reference in New Issue
Block a user