mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Make the nll op closer to the pytorch version + add a test. (#286)
This commit is contained in:
@ -1,8 +1,28 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
/// The negative loss likelihodd loss.
|
||||
///
|
||||
/// Arguments
|
||||
///
|
||||
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
||||
/// of categories. This is expected to contain log probabilities.
|
||||
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
||||
///
|
||||
/// The resulting tensor is a scalar containing the average value over the batch.
|
||||
pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||
let b_sz = target.dim(0)?;
|
||||
inp.gather(target, 1)?
|
||||
let b_sz = match target.dims() {
|
||||
&[b_sz] => b_sz,
|
||||
dims => candle::bail!("the target tensor should have a single dimension ({dims:?})"),
|
||||
};
|
||||
match inp.dims() {
|
||||
&[inp_b_sz, _] => {
|
||||
if inp_b_sz != b_sz {
|
||||
candle::bail!("batch size mismatch between inp ({inp_b_sz}) and target ({b_sz})")
|
||||
}
|
||||
}
|
||||
dims => candle::bail!("the target tensor should have two dimensions ({dims:?})"),
|
||||
}
|
||||
inp.gather(&target.unsqueeze(1)?, 1)?
|
||||
.sum_all()?
|
||||
.affine(-1f64 / b_sz as f64, 0.)
|
||||
}
|
||||
|
31
candle-nn/tests/loss.rs
Normal file
31
candle-nn/tests/loss.rs
Normal file
@ -0,0 +1,31 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
/* Equivalent python code:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
input = torch.tensor([
|
||||
[ 1.1050, 0.3013, -1.5394, -2.1528, -0.8634],
|
||||
[ 1.0730, -0.9419, -0.1670, -0.6582, 0.5061],
|
||||
[ 0.8318, 1.1154, -0.3610, 0.5351, 1.0830]])
|
||||
|
||||
target = torch.tensor([1, 0, 4])
|
||||
print(F.nll_loss(F.log_softmax(input, dim=1), target))
|
||||
*/
|
||||
/// Runs `candle_nn::loss::nll` on the log-softmax of a fixed 3x5 input and compares the scalar
/// result against a hard-coded expected value (presumably the output of the reference Python
/// snippet kept in this file).
#[test]
fn nll() -> Result<()> {
    let device = Device::Cpu;

    // Fixture values: three rows of five scores each, and one label per row.
    let scores = [
        [1.1050f32, 0.3013, -1.5394, -2.1528, -0.8634],
        [1.0730, -0.9419, -0.1670, -0.6582, 0.5061],
        [0.8318, 1.1154, -0.3610, 0.5351, 1.0830],
    ];
    let labels = [1u32, 0, 4];

    let input = Tensor::new(&scores, &device)?;
    let target = Tensor::new(&labels, &device)?;

    // Apply log-softmax over dimension 1 before feeding the result to the loss.
    let log_probs = candle_nn::ops::log_softmax(&input, 1)?;
    let loss = candle_nn::loss::nll(&log_probs, &target)?;

    let value = loss.to_vec0::<f32>()?;
    assert_eq!(value, 1.1312335);
    Ok(())
}
|
Reference in New Issue
Block a user