mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
Fix the tests for mkl. (#437)
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
mod test_utils;
|
||||
use test_utils::to_vec0_round;
|
||||
|
||||
/* Equivalent python code:
|
||||
import torch
|
||||
@ -27,8 +29,8 @@ fn nll_and_cross_entropy() -> Result<()> {
|
||||
|
||||
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
|
||||
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
|
||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
||||
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
|
||||
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
|
||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
||||
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user