Files
candle/candle-nn/tests/loss.rs
2023-08-27 09:42:22 +01:00

42 lines
1.2 KiB
Rust

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::test_utils::to_vec0_round;
use candle::{Device, Result, Tensor};
/* Equivalent PyTorch code (both prints give the reference loss 1.1312 asserted below):
import torch
import torch.nn.functional as F
input = torch.tensor([
[ 1.1050, 0.3013, -1.5394, -2.1528, -0.8634],
[ 1.0730, -0.9419, -0.1670, -0.6582, 0.5061],
[ 0.8318, 1.1154, -0.3610, 0.5351, 1.0830]])
target = torch.tensor([1, 0, 4])
print(F.nll_loss(F.log_softmax(input, dim=1), target))
print(F.cross_entropy(input, target))
*/
/// Checks that candle's two routes to the classification loss agree with PyTorch.
///
/// Route 1 applies `log_softmax` along dim 1 and then `nll`. Route 2 calls
/// `cross_entropy` directly on the raw logits. Both results are rounded to 4
/// decimal places and compared against the value from the reference script above.
#[test]
fn nll_and_cross_entropy() -> Result<()> {
    let device = Device::Cpu;

    // 3 rows of 5 logits each; every target below is a column index in 0..5.
    let logits = Tensor::new(
        &[
            [1.1050f32, 0.3013, -1.5394, -2.1528, -0.8634],
            [1.0730, -0.9419, -0.1670, -0.6582, 0.5061],
            [0.8318, 1.1154, -0.3610, 0.5351, 1.0830],
        ],
        &device,
    )?;
    let targets = Tensor::new(&[1u32, 0, 4], &device)?;

    // Reference value, rounded to 4 decimals, shared by both routes.
    let expected = 1.1312;

    // Route 1: explicit log-softmax over dim 1, then negative log-likelihood.
    let log_probs = candle_nn::ops::log_softmax(&logits, 1)?;
    let nll_loss = candle_nn::loss::nll(&log_probs, &targets)?;
    assert_eq!(to_vec0_round(&nll_loss, 4)?, expected);

    // Route 2: cross-entropy straight from the logits; must match route 1.
    let ce_loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
    assert_eq!(to_vec0_round(&ce_loss, 4)?, expected);

    Ok(())
}