Mirror of https://github.com/huggingface/candle.git (synced 2025-06-20 12:06:35 +00:00)
Add the cross-entropy loss. (#287)
@@ -26,3 +26,20 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
         .sum_all()?
         .affine(-1f64 / b_sz as f64, 0.)
 }
+
+/// The cross-entropy loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+///   of categories. This is expected to contain raw logits.
+/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
+pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    if inp.rank() != 2 {
+        candle::bail!("cross_entropy expects an input tensor of rank 2")
+    }
+    let inp = crate::ops::log_softmax(inp, 1)?;
+    nll(&inp, target)
+}
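For context, a minimal usage sketch of the new function. This assumes it lands in `candle_nn::loss` next to `nll`, and that the core crate is imported under the `candle` name as it is inside this workspace; the tensor values are illustrative only:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw logits for a batch of N = 2 samples over C = 3 categories.
    let logits = Tensor::new(&[[2.0f32, 0.5, -1.0], [0.1, 1.5, 0.3]], &device)?;
    // Ground-truth class indices as u32, one per sample.
    let targets = Tensor::new(&[0u32, 1u32], &device)?;
    // log_softmax is applied internally along the category dimension, then
    // nll averages the negative log-likelihood over the batch into a scalar.
    let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
    println!("loss: {}", loss.to_scalar::<f32>()?);
    Ok(())
}
```

Taking raw logits rather than probabilities lets the function fuse `log_softmax` with the negative log-likelihood, which is the numerically stable way to compute this loss.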