Fix the tests for mkl. (#437)

This commit is contained in:
Laurent Mazare
2023-08-14 08:09:27 +01:00
committed by GitHub
parent 9e7e6e0288
commit eab54e4490
2 changed files with 15 additions and 10 deletions

View File

@ -1,4 +1,6 @@
use candle::{Device, Result, Tensor};
mod test_utils;
use test_utils::to_vec0_round;
/* Equivalent python code:
import torch
@ -27,8 +29,8 @@ fn nll_and_cross_entropy() -> Result<()> {
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
Ok(())
}