Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    (inp - target)?.sqr()?.mean_all()
}

/// The binary cross-entropy with logit loss.
///
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
///   of categories. This is expected to contain raw logits.
/// * [target]: The ground truth labels as a tensor of dimension `N, C`, with the same dtype as
///   the input and values in `[0, 1]`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    // Map the raw logits to probabilities.
    let inp = crate::ops::sigmoid(inp)?;

    // t * log(p) and (1 - t) * log(1 - p); affine(-1., 1.) computes 1 - x.
    let left_side = target * inp.log()?;
    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;

    // Negate and average over every element of the batch.
    let loss = left_side? + right_side?;
    let loss = loss?.neg()?.mean_all()?;

    Ok(loss)
}
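Written out, with logits x and targets t, the loss this function averages is the standard binary cross-entropy applied to sigmoid-activated inputs:

\mathcal{L} = -\frac{1}{NC}\sum_{n=1}^{N}\sum_{c=1}^{C}\Bigl[\, t_{nc}\,\log \sigma(x_{nc}) + (1 - t_{nc})\,\log\bigl(1 - \sigma(x_{nc})\bigr) \Bigr]

where \sigma is the sigmoid; the two `affine(-1., 1.)` calls above compute 1 - t and 1 - \sigma(x) respectively.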
@@ -39,3 +39,50 @@ fn nll_and_cross_entropy() -> Result<()> {
    assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
    Ok(())
}

/* Equivalent python code:
import torch
import torch.nn.functional as F

inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
    [ 0.0419,  0.0763, -1.0457, -1.6692],
    [-1.0494,  0.8111,  1.5723,  1.2315],
    [ 1.3081,  0.6641,  1.1802, -0.2547],
    [ 0.5292,  0.7636,  0.3692, -0.8318]])

target = torch.Tensor([[0., 1., 0., 0.],
    [0., 1., 0., 0.],
    [0., 0., 0., 1.],
    [1., 0., 0., 0.],
    [0., 0., 1., 0.]])

print(F.binary_cross_entropy_with_logits(inp, target))
*/
#[test]
fn binary_cross_entropy_with_logit() -> Result<()> {
    let cpu = Device::Cpu;

    let inp = [
        [2.3611f32, -0.8813, -0.5006, -0.2178],
        [0.0419, 0.0763, -1.0457, -1.6692],
        [-1.0494, 0.8111, 1.5723, 1.2315],
        [1.3081, 0.6641, 1.1802, -0.2547],
        [0.5292, 0.7636, 0.3692, -0.8318],
    ];

    let target = [
        [0.0f32, 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
    ];

    let inp = Tensor::new(&inp, &cpu)?;
    let target = Tensor::new(&target, &cpu)?;

    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;

    assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
    Ok(())
}
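One caveat worth noting: computing the sigmoid first and taking the log afterwards, as this implementation does, can hit log(0) when a logit is strongly negative, since sigmoid underflows to 0 in f32. PyTorch's F.binary_cross_entropy_with_logits (the reference used in the test comment) instead evaluates the algebraically equivalent stable form max(x, 0) - x*t + ln(1 + exp(-|x|)). A minimal sketch of that form on plain f32 slices, with a hypothetical helper name that is not part of candle's API:

// Hypothetical helper, shown only to illustrate the stable formulation:
// max(x, 0) - x * t + ln(1 + exp(-|x|)) per element, then the mean.
// ln_1p keeps precision when exp(-|x|) is tiny.
fn bce_with_logits_stable(logits: &[f32], targets: &[f32]) -> f32 {
    assert_eq!(logits.len(), targets.len());
    let sum: f32 = logits
        .iter()
        .zip(targets)
        .map(|(&x, &t)| x.max(0.0) - x * t + (-x.abs()).exp().ln_1p())
        .sum();
    sum / logits.len() as f32
}

Since the two formulations agree algebraically, feeding this helper the flattened inp/target values from the test above should reproduce the 0.8224 figure up to rounding.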