Move some shared functions to the nn module. (#221)

This commit is contained in:
Laurent Mazare
2023-07-22 14:25:11 +02:00
committed by GitHub
parent 43c7223292
commit 1f26042693
4 changed files with 24 additions and 19 deletions

8
candle-nn/src/loss.rs Normal file
View File

@ -0,0 +1,8 @@
use candle::{Result, Tensor};
pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
let b_sz = target.dim(0)?;
inp.gather(target, 1)?
.sum_all()?
.affine(-1f64 / b_sz as f64, 0.)
}