mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the mse loss. (#723)
This commit is contained in:
@ -1370,6 +1370,10 @@ impl Tensor {
|
|||||||
self.sum(dims)
|
self.sum(dims)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn mean_all(&self) -> Result<Tensor> {
|
||||||
|
self.sum_all()? / self.elem_count() as f64
|
||||||
|
}
|
||||||
|
|
||||||
fn flatten_<D1: Dim, D2: Dim>(
|
fn flatten_<D1: Dim, D2: Dim>(
|
||||||
&self,
|
&self,
|
||||||
start_dim: Option<D1>,
|
start_dim: Option<D1>,
|
||||||
|
@ -43,3 +43,8 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
|||||||
let inp = crate::ops::log_softmax(inp, 1)?;
|
let inp = crate::ops::log_softmax(inp, 1)?;
|
||||||
nll(&inp, target)
|
nll(&inp, target)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The mean squared error loss.
|
||||||
|
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||||
|
(inp - target)?.sqr()?.mean_all()
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user