From 1064b9b0314edcea6fc1f748fc883a6343db9770 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 31 Jul 2023 14:26:36 +0100
Subject: [PATCH] Add the cross-entropy loss. (#287)

---
 candle-nn/src/loss.rs   | 17 +++++++++++++++++
 candle-nn/tests/loss.rs |  5 ++++-
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index b380c426..9d15719f 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -26,3 +26,20 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
         .sum_all()?
         .affine(-1f64 / b_sz as f64, 0.)
 }
+
+/// The cross-entropy loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+///   of categories. This is expected to contain raw logits.
+/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
+pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    if inp.rank() != 2 {
+        candle::bail!("cross_entropy expects an input tensor of rank 2")
+    }
+    let inp = crate::ops::log_softmax(inp, 1)?;
+    nll(&inp, target)
+}
diff --git a/candle-nn/tests/loss.rs b/candle-nn/tests/loss.rs
index 12057191..0811fa39 100644
--- a/candle-nn/tests/loss.rs
+++ b/candle-nn/tests/loss.rs
@@ -10,9 +10,10 @@ input = torch.tensor([
 target = torch.tensor([1, 0, 4])
 
 print(F.nll_loss(F.log_softmax(input, dim=1), target))
+print(F.cross_entropy(input, target))
 */
 #[test]
-fn nll() -> Result<()> {
+fn nll_and_cross_entropy() -> Result<()> {
     let cpu = Device::Cpu;
     let input = Tensor::new(
         &[
@@ -27,5 +28,7 @@ fn nll() -> Result<()> {
     let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
     let loss = candle_nn::loss::nll(&log_softmax, &target)?;
     assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
+    let loss = candle_nn::loss::cross_entropy(&input, &target)?;
+    assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
     Ok(())
 }
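
For reference, a minimal usage sketch of the new `candle_nn::loss::cross_entropy` helper (not part of the patch; the logits and targets below are illustrative placeholders, not the values used in the test above). Since the function applies `log_softmax` internally, it takes raw logits of shape `(N, C)` and `u32` class indices of shape `(N)`:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Raw logits for a batch of 2 samples over 3 categories, shape `(N, C)`.
    let logits = Tensor::new(&[[2.0f32, 0.5, -1.0], [-0.5, 1.5, 0.25]], &dev)?;
    // Ground-truth class indices as u32, one per sample, shape `(N)`.
    let target = Tensor::new(&[0u32, 2], &dev)?;
    // cross_entropy applies log_softmax to the logits, then averages the
    // negative log-likelihood over the batch, returning a scalar tensor.
    let loss = candle_nn::loss::cross_entropy(&logits, &target)?;
    println!("loss = {}", loss.to_vec0::<f32>()?);
    Ok(())
}
```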