mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the mse loss. (#723)
This commit is contained in:
@ -43,3 +43,8 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||
let inp = crate::ops::log_softmax(inp, 1)?;
|
||||
nll(&inp, target)
|
||||
}
|
||||
|
||||
/// The mean squared error loss.
|
||||
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||
(inp - target)?.sqr()?.mean_all()
|
||||
}
|
||||
|
Reference in New Issue
Block a user