diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 0f48dc62..e181f240 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1370,6 +1370,10 @@ impl Tensor { self.sum(dims) } + pub fn mean_all(&self) -> Result { + self.sum_all()? / self.elem_count() as f64 + } + fn flatten_( &self, start_dim: Option, diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs index 9d15719f..cddf278e 100644 --- a/candle-nn/src/loss.rs +++ b/candle-nn/src/loss.rs @@ -43,3 +43,8 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result { let inp = crate::ops::log_softmax(inp, 1)?; nll(&inp, target) } + +/// The mean squared error loss. +pub fn mse(inp: &Tensor, target: &Tensor) -> Result { + (inp - target)?.sqr()?.mean_all() +}