mirror of https://github.com/huggingface/candle.git
Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
 pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
     (inp - target)?.sqr()?.mean_all()
 }
+
+/// The binary cross-entropy with logit loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+///   of categories. This is expected to be raw logits.
+/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the
+///   batch size and `C` the number of categories.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
+pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    let inp = crate::ops::sigmoid(inp)?;
+
+    let left_side = target * inp.log()?;
+    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
+
+    let loss = left_side? + right_side?;
+    let loss = loss?.neg()?.mean_all()?;
+
+    Ok(loss)
+}
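For reference, a minimal usage sketch (not part of the commit). It assumes the published crate names candle_core and candle_nn, that binary_cross_entropy_with_logit is reachable at candle_nn::loss, and that the targets are supplied as f32 values in {0, 1} so the element-wise products inside the function type-check against the sigmoid output:

use candle_core::{Device, Result, Tensor};
use candle_nn::loss::binary_cross_entropy_with_logit;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw logits for a batch of 2 samples and 3 categories (N = 2, C = 3).
    let inp = Tensor::new(&[[2.3f32, -1.2, 0.7], [-0.4, 1.9, -2.1]], &device)?;
    // Multi-label ground truth in {0, 1}, same shape and dtype as the logits
    // (assumption: f32 targets, so the products inside the loss type-check).
    let target = Tensor::new(&[[1.0f32, 0.0, 1.0], [0.0, 1.0, 0.0]], &device)?;
    // Scalar loss averaged over all N * C entries.
    let loss = binary_cross_entropy_with_logit(&inp, &target)?;
    println!("bce-with-logit loss: {}", loss.to_scalar::<f32>()?);
    Ok(())
}

Note that the added function computes sigmoid followed by log, which matches the textbook definition described in the doc comment; a fused log-sum-exp formulation (as used by PyTorch's BCEWithLogitsLoss) would be more numerically stable for large-magnitude logits.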