Add the cross-entropy loss. (#287)

Laurent Mazare
2023-07-31 14:26:36 +01:00
committed by GitHub
parent ffeafbfc43
commit 1064b9b031
2 changed files with 21 additions and 1 deletion


@@ -26,3 +26,20 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
         .sum_all()?
         .affine(-1f64 / b_sz as f64, 0.)
 }
+
+/// The cross-entropy loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+///   of categories. This is expected to contain raw logits.
+/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
+pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    if inp.rank() != 2 {
+        candle::bail!("cross_entropy expects an input tensor of rank 2")
+    }
+    let inp = crate::ops::log_softmax(inp, 1)?;
+    nll(&inp, target)
+}
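
For context, a minimal usage sketch of the new function (not part of the commit; the crate alias `candle` for candle-core, the example logits and labels, and the `main` wrapper are illustrative assumptions):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Raw logits for a batch of 2 samples over 3 categories (N = 2, C = 3).
    let logits = Tensor::new(&[[1.0f32, 2.0, 0.5], [0.1, -1.0, 3.0]], &dev)?;
    // Ground-truth class indices as a u32 tensor of dimension N.
    let target = Tensor::new(&[1u32, 2u32], &dev)?;
    // cross_entropy applies log_softmax over the category dimension, then
    // averages the negative log-likelihood over the batch.
    let loss = candle_nn::loss::cross_entropy(&logits, &target)?;
    println!("cross-entropy loss: {}", loss.to_vec0::<f32>()?);
    Ok(())
}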


@@ -10,9 +10,10 @@ input = torch.tensor([
 target = torch.tensor([1, 0, 4])
 print(F.nll_loss(F.log_softmax(input, dim=1), target))
+print(F.cross_entropy(input, target))
 */
 #[test]
-fn nll() -> Result<()> {
+fn nll_and_cross_entropy() -> Result<()> {
     let cpu = Device::Cpu;
     let input = Tensor::new(
         &[
@@ -27,5 +28,7 @@ fn nll() -> Result<()> {
     let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
     let loss = candle_nn::loss::nll(&log_softmax, &target)?;
     assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
+    let loss = candle_nn::loss::cross_entropy(&input, &target)?;
+    assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
     Ok(())
 }
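
A note on why both assertions expect the same value, 1.1312335 (explanation added here, not part of the commit): cross_entropy is nll composed with log_softmax, so for logits x of shape N x C and labels y,

    \mathrm{cross\_entropy}(x, y) = -\frac{1}{N}\sum_{i=1}^{N} \log\big(\mathrm{softmax}(x_i)\big)_{y_i} = \mathrm{nll}\big(\mathrm{log\_softmax}(x),\, y\big),

which is the same quantity the PyTorch reference computes with F.cross_entropy(input, target).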