Add the cross-entropy loss. (#287)

This commit is contained in:
Laurent Mazare
2023-07-31 14:26:36 +01:00
committed by GitHub
parent ffeafbfc43
commit 1064b9b031
2 changed files with 21 additions and 1 deletions

View File

@ -10,9 +10,10 @@ input = torch.tensor([
target = torch.tensor([1, 0, 4])
print(F.nll_loss(F.log_softmax(input, dim=1), target))
print(F.cross_entropy(input, target))
*/
#[test]
fn nll() -> Result<()> {
fn nll_and_cross_entropy() -> Result<()> {
let cpu = Device::Cpu;
let input = Tensor::new(
&[
@ -27,5 +28,7 @@ fn nll() -> Result<()> {
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
Ok(())
}