Add the round-to function. (#1039)

This commit is contained in:
Laurent Mazare
2023-10-05 20:28:09 +01:00
committed by GitHub
parent f47bd9bab5
commit 7f7d95e2c3
2 changed files with 18 additions and 0 deletions

View File

@ -496,6 +496,15 @@ impl Tensor {
unary_op!(floor, Floor);
unary_op!(round, Round);
/// Round element of the input tensor to the nearest integer.
///
/// If the number of decimals is negative, it specifies the number of positions to the left of
/// the decimal point.
pub fn round_to(&self, decimals: i32) -> Result<Self> {
let mult = 10f64.powi(decimals);
(self * mult)?.round()? * (1f64 / mult)
}
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
/// dimensions, an error is returned instead.
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {

View File

@ -105,6 +105,15 @@ fn unary_op(device: &Device) -> Result<()> {
test_utils::to_vec2_round(&tensor.round()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
);
let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
[2997.92, 314.16]
);
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.]
);
Ok(())
}