Add the mse loss. (#723)

This commit is contained in:
Laurent Mazare
2023-09-03 11:51:40 +02:00
committed by GitHub
parent 84d003ff53
commit 74a82c358a
2 changed files with 9 additions and 0 deletions

View File

@ -1370,6 +1370,10 @@ impl Tensor {
self.sum(dims)
}
pub fn mean_all(&self) -> Result<Tensor> {
self.sum_all()? / self.elem_count() as f64
}
fn flatten_<D1: Dim, D2: Dim>(
&self,
start_dim: Option<D1>,