mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Add the cross-entropy loss. (#287)
This commit is contained in:
@ -10,9 +10,10 @@ input = torch.tensor([
|
||||
|
||||
target = torch.tensor([1, 0, 4])
|
||||
print(F.nll_loss(F.log_softmax(input, dim=1), target))
|
||||
print(F.cross_entropy(input, target))
|
||||
*/
|
||||
#[test]
|
||||
fn nll() -> Result<()> {
|
||||
fn nll_and_cross_entropy() -> Result<()> {
|
||||
let cpu = Device::Cpu;
|
||||
let input = Tensor::new(
|
||||
&[
|
||||
@ -27,5 +28,7 @@ fn nll() -> Result<()> {
|
||||
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
|
||||
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
|
||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
||||
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
|
||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user