mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the mse loss. (#723)
This commit is contained in:
@ -1370,6 +1370,10 @@ impl Tensor {
|
||||
self.sum(dims)
|
||||
}
|
||||
|
||||
pub fn mean_all(&self) -> Result<Tensor> {
|
||||
self.sum_all()? / self.elem_count() as f64
|
||||
}
|
||||
|
||||
fn flatten_<D1: Dim, D2: Dim>(
|
||||
&self,
|
||||
start_dim: Option<D1>,
|
||||
|
@ -43,3 +43,8 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||
let inp = crate::ops::log_softmax(inp, 1)?;
|
||||
nll(&inp, target)
|
||||
}
|
||||
|
||||
/// The mean squared error loss.
|
||||
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||
(inp - target)?.sqr()?.mean_all()
|
||||
}
|
||||
|
Reference in New Issue
Block a user