diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index 937510c7..5bc2e99b 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -178,10 +178,7 @@ fn training_loop(
 
     let train_labels = m.train_labels;
     let train_images = m.train_images.to_device(&dev)?;
-    let train_labels = train_labels
-        .to_dtype(DType::U32)?
-        .unsqueeze(1)?
-        .to_device(&dev)?;
+    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
 
     let mut vs = VarStore::new(DType::F32, dev.clone());
     let model = M::new(vs.clone())?;
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index d388af6c..b380c426 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -1,8 +1,28 @@
 use candle::{Result, Tensor};
 
+/// The negative log likelihood loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+///   of categories. This is expected to contain log probabilities.
+/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
 pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
-    let b_sz = target.dim(0)?;
-    inp.gather(target, 1)?
+    let b_sz = match target.dims() {
+        &[b_sz] => b_sz,
+        dims => candle::bail!("the target tensor should have a single dimension ({dims:?})"),
+    };
+    match inp.dims() {
+        &[inp_b_sz, _] => {
+            if inp_b_sz != b_sz {
+                candle::bail!("batch size mismatch between inp ({inp_b_sz}) and target ({b_sz})")
+            }
+        }
+        dims => candle::bail!("the input tensor should have two dimensions ({dims:?})"),
+    }
+    inp.gather(&target.unsqueeze(1)?, 1)?
         .sum_all()?
         .affine(-1f64 / b_sz as f64, 0.)
 }
diff --git a/candle-nn/tests/loss.rs b/candle-nn/tests/loss.rs
new file mode 100644
index 00000000..12057191
--- /dev/null
+++ b/candle-nn/tests/loss.rs
@@ -0,0 +1,31 @@
+use candle::{Device, Result, Tensor};
+
+/* Equivalent python code:
+import torch
+import torch.nn.functional as F
+input = torch.tensor([
+    [ 1.1050,  0.3013, -1.5394, -2.1528, -0.8634],
+    [ 1.0730, -0.9419, -0.1670, -0.6582,  0.5061],
+    [ 0.8318,  1.1154, -0.3610,  0.5351,  1.0830]])
+
+target = torch.tensor([1, 0, 4])
+print(F.nll_loss(F.log_softmax(input, dim=1), target))
+*/
+#[test]
+fn nll() -> Result<()> {
+    let cpu = Device::Cpu;
+    let input = Tensor::new(
+        &[
+            [1.1050f32, 0.3013, -1.5394, -2.1528, -0.8634],
+            [1.0730, -0.9419, -0.1670, -0.6582, 0.5061],
+            [0.8318, 1.1154, -0.3610, 0.5351, 1.0830],
+        ],
+        &cpu,
+    )?;
+    let target = Tensor::new(&[1u32, 0, 4], &cpu)?;
+
+    let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
+    let loss = candle_nn::loss::nll(&log_softmax, &target)?;
+    assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
+    Ok(())
+}
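
For reference, a minimal usage sketch of the new entry point, mirroring the test above. The paths candle_nn::ops::log_softmax and candle_nn::loss::nll come straight from this diff; the tensor values and main scaffolding are illustrative only.

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Raw logits for a batch of two examples over three classes (shape `N, C`).
    let logits = Tensor::new(&[[0.5f32, 1.0, -0.5], [2.0, 0.1, 0.3]], &dev)?;
    // Ground-truth class indices: one u32 per example, a tensor of dimension `N`.
    let target = Tensor::new(&[1u32, 0], &dev)?;
    // `nll` expects log probabilities, so apply log_softmax over the class dimension first.
    let log_probs = candle_nn::ops::log_softmax(&logits, 1)?;
    // The result is a scalar: the average negative log likelihood over the batch.
    let loss = candle_nn::loss::nll(&log_probs, &target)?;
    println!("nll loss = {}", loss.to_vec0::<f32>()?);
    Ok(())
}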