diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index c932cd51..7295c350 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -496,6 +496,15 @@ impl Tensor { unary_op!(floor, Floor); unary_op!(round, Round); + /// Round element of the input tensor to the nearest integer. + /// + /// If the number of decimals is negative, it specifies the number of positions to the left of + /// the decimal point. + pub fn round_to(&self, decimals: i32) -> Result { + let mult = 10f64.powi(decimals); + (self * mult)?.round()? * (1f64 / mult) + } + /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple /// dimensions, an error is returned instead. pub fn to_scalar(&self) -> Result { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index a50f3a6c..c938ffea 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -105,6 +105,15 @@ fn unary_op(device: &Device) -> Result<()> { test_utils::to_vec2_round(&tensor.round()?, 4)?, [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]] ); + let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?; + assert_eq!( + test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?, + [2997.92, 314.16] + ); + assert_eq!( + test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?, + [3000.0, 300.] + ); Ok(()) }