diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index e4f556ab..08e2f628 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -7,7 +7,7 @@ //! # Example //! //! ```rust -//! use candle::{Tensor, Device::Cpu}; +//! use candle::{Tensor, Device::Cpu, test_utils::to_vec3_round}; //! use candle_nn::{LayerNorm, Module}; //! # fn main() -> candle::Result<()> { //! @@ -20,10 +20,10 @@ //! &Cpu)?; //! let ys = layer.forward(&xs)?; //! assert_eq!( -//! ys.to_vec3::()?, -//! &[[[-1.2247356, 0.0, 1.2247356], -//! [-1.2247356, 0.0, 1.2247356], -//! [ 1.2247356, 0.0, -1.2247356]]]); +//! to_vec3_round(&ys, 4)?, +//! &[[[-1.2247, 0.0, 1.2247], +//! [-1.2247, 0.0, 1.2247], +//! [ 1.2247, 0.0, -1.2247]]]); //! # Ok(()) } //! ``` //! diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 63f73dfe..c3b6ffa2 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -4,14 +4,14 @@ use candle::{Result, Tensor}; /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1. /// /// ```rust -/// use candle::{Tensor, Device}; +/// use candle::{Tensor, Device, test_utils::to_vec2_round}; /// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?; /// let a = candle_nn::ops::softmax(&a, 1)?; /// assert_eq!( -/// a.to_vec2::()?, +/// to_vec2_round(&a, 4)?, /// &[ -/// [0.13447072, 0.3655293, 0.13447072, 0.3655293], -/// [0.0048928666, 0.26714146, 0.7261658, 0.0017999851] +/// [0.1345, 0.3655, 0.1345, 0.3655], +/// [0.0049, 0.2671, 0.7262, 0.0018] /// ]); /// # Ok::<(), candle::Error>(()) /// ```