Add the mse loss. (#723)

This commit is contained in:
Laurent Mazare
2023-09-03 11:51:40 +02:00
committed by GitHub
parent 84d003ff53
commit 74a82c358a
2 changed files with 9 additions and 0 deletions

View File

@ -43,3 +43,8 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
let inp = crate::ops::log_softmax(inp, 1)?;
nll(&inp, target)
}
/// The mean squared error loss.
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
(inp - target)?.sqr()?.mean_all()
}