mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the cross-entropy loss. (#287)
This commit is contained in:
@ -26,3 +26,20 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
|||||||
.sum_all()?
|
.sum_all()?
|
||||||
.affine(-1f64 / b_sz as f64, 0.)
|
.affine(-1f64 / b_sz as f64, 0.)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The cross-entropy loss.
|
||||||
|
///
|
||||||
|
/// Arguments
|
||||||
|
///
|
||||||
|
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
||||||
|
/// of categories. This is expected to raw logits.
|
||||||
|
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
||||||
|
///
|
||||||
|
/// The resulting tensor is a scalar containing the average value over the batch.
|
||||||
|
pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||||
|
if inp.rank() != 2 {
|
||||||
|
candle::bail!("cross_entropy expects an input tensor of rank 2")
|
||||||
|
}
|
||||||
|
let inp = crate::ops::log_softmax(inp, 1)?;
|
||||||
|
nll(&inp, target)
|
||||||
|
}
|
||||||
|
@ -10,9 +10,10 @@ input = torch.tensor([
|
|||||||
|
|
||||||
target = torch.tensor([1, 0, 4])
|
target = torch.tensor([1, 0, 4])
|
||||||
print(F.nll_loss(F.log_softmax(input, dim=1), target))
|
print(F.nll_loss(F.log_softmax(input, dim=1), target))
|
||||||
|
print(F.cross_entropy(input, target))
|
||||||
*/
|
*/
|
||||||
#[test]
|
#[test]
|
||||||
fn nll() -> Result<()> {
|
fn nll_and_cross_entropy() -> Result<()> {
|
||||||
let cpu = Device::Cpu;
|
let cpu = Device::Cpu;
|
||||||
let input = Tensor::new(
|
let input = Tensor::new(
|
||||||
&[
|
&[
|
||||||
@ -27,5 +28,7 @@ fn nll() -> Result<()> {
|
|||||||
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
|
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
|
||||||
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
|
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
|
||||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
||||||
|
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
|
||||||
|
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user