More robust tests (so that they pass on accelerate). (#679)

This commit is contained in:
Laurent Mazare
2023-08-30 19:10:10 +02:00
committed by GitHub
parent 9874d843f1
commit 2047d34b7c
2 changed files with 9 additions and 9 deletions

View File

@ -7,7 +7,7 @@
//! # Example //! # Example
//! //!
//! ```rust //! ```rust
//! use candle::{Tensor, Device::Cpu}; //! use candle::{Tensor, Device::Cpu, test_utils::to_vec3_round};
//! use candle_nn::{LayerNorm, Module}; //! use candle_nn::{LayerNorm, Module};
//! # fn main() -> candle::Result<()> { //! # fn main() -> candle::Result<()> {
//! //!
@ -20,10 +20,10 @@
//! &Cpu)?; //! &Cpu)?;
//! let ys = layer.forward(&xs)?; //! let ys = layer.forward(&xs)?;
//! assert_eq!( //! assert_eq!(
//! ys.to_vec3::<f32>()?, //! to_vec3_round(&ys, 4)?,
//! &[[[-1.2247356, 0.0, 1.2247356], //! &[[[-1.2247, 0.0, 1.2247],
//! [-1.2247356, 0.0, 1.2247356], //! [-1.2247, 0.0, 1.2247],
//! [ 1.2247356, 0.0, -1.2247356]]]); //! [ 1.2247, 0.0, -1.2247]]]);
//! # Ok(()) } //! # Ok(()) }
//! ``` //! ```
//! //!

View File

@ -4,14 +4,14 @@ use candle::{Result, Tensor};
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1. /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
/// ///
/// ```rust /// ```rust
/// use candle::{Tensor, Device}; /// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?; /// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = candle_nn::ops::softmax(&a, 1)?; /// let a = candle_nn::ops::softmax(&a, 1)?;
/// assert_eq!( /// assert_eq!(
/// a.to_vec2::<f32>()?, /// to_vec2_round(&a, 4)?,
/// &[ /// &[
/// [0.13447072, 0.3655293, 0.13447072, 0.3655293], /// [0.1345, 0.3655, 0.1345, 0.3655],
/// [0.0048928666, 0.26714146, 0.7261658, 0.0017999851] /// [0.0049, 0.2671, 0.7262, 0.0018]
/// ]); /// ]);
/// # Ok::<(), candle::Error>(()) /// # Ok::<(), candle::Error>(())
/// ``` /// ```