mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00

* add bce with logit loss * add bce with logit loss * remove imports * fix tiny bug * add test documentation and refactor function * fix test cases and formatting
89 lines
2.4 KiB
Rust
89 lines
2.4 KiB
Rust
#[cfg(feature = "mkl")]
|
|
extern crate intel_mkl_src;
|
|
|
|
#[cfg(feature = "accelerate")]
|
|
extern crate accelerate_src;
|
|
|
|
use candle::test_utils::to_vec0_round;
|
|
use candle::{Device, Result, Tensor};
|
|
|
|
/* Equivalent python code:
|
|
import torch
|
|
import torch.nn.functional as F
|
|
input = torch.tensor([
|
|
[ 1.1050, 0.3013, -1.5394, -2.1528, -0.8634],
|
|
[ 1.0730, -0.9419, -0.1670, -0.6582, 0.5061],
|
|
[ 0.8318, 1.1154, -0.3610, 0.5351, 1.0830]])
|
|
|
|
target = torch.tensor([1, 0, 4])
|
|
print(F.nll_loss(F.log_softmax(input, dim=1), target))
|
|
print(F.cross_entropy(input, target))
|
|
*/
|
|
#[test]
|
|
fn nll_and_cross_entropy() -> Result<()> {
|
|
let cpu = Device::Cpu;
|
|
let input = Tensor::new(
|
|
&[
|
|
[1.1050f32, 0.3013, -1.5394, -2.1528, -0.8634],
|
|
[1.0730, -0.9419, -0.1670, -0.6582, 0.5061],
|
|
[0.8318, 1.1154, -0.3610, 0.5351, 1.0830],
|
|
],
|
|
&cpu,
|
|
)?;
|
|
let target = Tensor::new(&[1u32, 0, 4], &cpu)?;
|
|
|
|
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
|
|
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
|
|
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
|
|
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
|
|
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
|
|
Ok(())
|
|
}
|
|
|
|
/* Equivalent python code:
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
|
|
[ 0.0419, 0.0763, -1.0457, -1.6692],
|
|
[-1.0494, 0.8111, 1.5723, 1.2315],
|
|
[ 1.3081, 0.6641, 1.1802, -0.2547],
|
|
[ 0.5292, 0.7636, 0.3692, -0.8318]])
|
|
|
|
target = torch.Tensor([[0., 1., 0., 0.],
|
|
[0., 1., 0., 0.],
|
|
[0., 0., 0., 1.],
|
|
[1., 0., 0., 0.],
|
|
[0., 0., 1., 0.]])
|
|
|
|
print(F.binary_cross_entropy_with_logits(inp, target))
|
|
*/
|
|
#[test]
|
|
fn binary_cross_entropy_with_logit() -> Result<()> {
|
|
let cpu = Device::Cpu;
|
|
|
|
let inp = [
|
|
[2.3611f32, -0.8813, -0.5006, -0.2178],
|
|
[0.0419, 0.0763, -1.0457, -1.6692],
|
|
[-1.0494, 0.8111, 1.5723, 1.2315],
|
|
[1.3081, 0.6641, 1.1802, -0.2547],
|
|
[0.5292, 0.7636, 0.3692, -0.8318],
|
|
];
|
|
|
|
let target = [
|
|
[0.0f32, 1., 0., 0.],
|
|
[0., 1., 0., 0.],
|
|
[0., 0., 0., 1.],
|
|
[1., 0., 0., 0.],
|
|
[0., 0., 1., 0.],
|
|
];
|
|
|
|
let inp = Tensor::new(&inp, &cpu)?;
|
|
let target = Tensor::new(&target, &cpu)?;
|
|
|
|
let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
|
|
|
|
assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
|
|
Ok(())
|
|
}
|