From 74a82c358a9fd160b52230a28d9929c32e95ecba Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 3 Sep 2023 11:51:40 +0200 Subject: [PATCH] Add the mse loss. (#723) --- candle-core/src/tensor.rs | 4 ++++ candle-nn/src/loss.rs | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 0f48dc62..e181f240 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1370,6 +1370,10 @@ impl Tensor { self.sum(dims) } + pub fn mean_all(&self) -> Result { + self.sum_all()? / self.elem_count() as f64 + } + fn flatten_( &self, start_dim: Option, diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs index 9d15719f..cddf278e 100644 --- a/candle-nn/src/loss.rs +++ b/candle-nn/src/loss.rs @@ -43,3 +43,8 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result { let inp = crate::ops::log_softmax(inp, 1)?; nll(&inp, target) } + +/// The mean squared error loss. +pub fn mse(inp: &Tensor, target: &Tensor) -> Result { + (inp - target)?.sqr()?.mean_all() +}