diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 42d660f4..05791ed1 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -482,6 +482,11 @@ impl Tensor { } } + /// An alias for `to_scalar`. + pub fn to_vec0(&self) -> Result { + self.to_scalar::() + } + /// This operation multiplies the input tensor by `mul` then adds `add` and return the result. /// The input values `mul` and `add` are casted to the appropriate type so some rounding might /// be performed. diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 6f30b5b7..591b504a 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -114,23 +114,23 @@ fn unary_grad(device: &Device) -> Result<()> { let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!( - y.to_vec1::()?, - [0.14112, 0.84147096, -0.7568025, 0.14943814], + test_utils::to_vec1_round(&y, 4)?, + [0.1411, 0.8415, -0.7568, 0.1494], ); assert_eq!( - grad_x.to_vec1::()?, - [-0.9899925, 0.5403023, -0.6536436, 0.9887711], + test_utils::to_vec1_round(grad_x, 4)?, + [-0.99, 0.5403, -0.6536, 0.9888], ); let y = x.cos()?; let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!( - y.to_vec1::()?, - [-0.9899925, 0.5403023, -0.6536436, 0.9887711], + test_utils::to_vec1_round(&y, 4)?, + [-0.99, 0.5403, -0.6536, 0.9888], ); assert_eq!( - grad_x.to_vec1::()?, - [-0.14112, -0.84147096, 0.7568025, -0.14943814], + test_utils::to_vec1_round(grad_x, 4)?, + [-0.1411, -0.8415, 0.7568, -0.1494], ); let y = x.sqr()?; let grads = y.backward()?; diff --git a/candle-core/tests/test_utils.rs b/candle-core/tests/test_utils.rs index 4dd44b64..5f7d3117 100644 --- a/candle-core/tests/test_utils.rs +++ b/candle-core/tests/test_utils.rs @@ -20,6 +20,23 @@ macro_rules! test_device { }; } +pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result> { + let b = 10f32.powi(digits); + let t = t.to_vec1::()?; + let t = t.iter().map(|t| f32::round(t * b) / b).collect(); + Ok(t) +} + +pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result>> { + let b = 10f32.powi(digits); + let t = t.to_vec2::()?; + let t = t + .iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect(); + Ok(t) +} + pub fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { let b = 10f32.powi(digits); let t = t.to_vec3::()?;