Add Binary Cross Entropy With Logit Loss to nn crate (#1157)

* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting
Author: Ogundepo Odunayo
Date: 2023-10-23 12:12:44 -04:00 (committed by GitHub)
Parent: 25c3cc4149
Commit: 86e1803191
2 changed files with 69 additions and 0 deletions

@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
(inp - target)?.sqr()?.mean_all()
}

/// The binary cross-entropy with logit loss.
///
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
///   of categories. This is expected to contain raw logits.
/// * [target]: The ground truth labels as a tensor of the same shape and floating-point dtype as
///   `inp`, with values in `{0., 1.}` (the elementwise products below require matching dtypes).
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    // Map the raw logits to probabilities in (0, 1).
    let inp = crate::ops::sigmoid(inp)?;
    // y * log(p)
    let left_side = target * inp.log()?;
    // (1 - y) * log(1 - p), with affine(-1., 1.) computing 1 - x.
    let right_side = target.affine(-1., 1.)? * inp.affine(-1., 1.)?.log()?;
    // loss = -mean(y * log(p) + (1 - y) * log(1 - p))
    let loss = (left_side? + right_side?)?.neg()?.mean_all()?;
    Ok(loss)
}
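
For reference, a minimal usage sketch of the new function. This is not part of the commit: the crate paths (`candle_core`, `candle_nn::loss`) follow the crates' public layout, but the tensor values are illustrative assumptions rather than the PR's actual test data. Note the targets are floating point, matching the logits' dtype.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw logits for a batch of 2 samples with 3 independent binary labels each.
    let inp = Tensor::new(&[[0.5f32, -1.0, 2.0], [1.5, 0.0, -0.5]], &device)?;
    // Ground truth labels: same shape and dtype as the logits, values in {0., 1.}.
    let target = Tensor::new(&[[1.0f32, 0.0, 1.0], [0.0, 1.0, 0.0]], &device)?;
    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
    println!("bce-with-logits loss: {}", loss.to_scalar::<f32>()?);
    Ok(())
}

One design note: computing `sigmoid` and then `log` as two separate steps can underflow for large negative logits; fused log-sigmoid formulations (as in PyTorch's `binary_cross_entropy_with_logits`) trade a little clarity for better numerical stability.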