Add logsumexp function (#1424)

This commit is contained in:
Wenqing Zong
2023-12-12 16:32:17 +00:00
committed by GitHub
parent 18eb87f25f
commit 77252ffb82
2 changed files with 33 additions and 1 deletions

View File

@ -2565,6 +2565,13 @@ impl Tensor {
}
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
}
/// Returns log(sum(exp(tensor), dim)).
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
}
macro_rules! bin_trait {