mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add logsumexp function (#1424)
This commit is contained in:
@ -2565,6 +2565,13 @@ impl Tensor {
|
||||
}
|
||||
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
|
||||
}
|
||||
|
||||
/// Returns log(sum(exp(tensor), dim)).
|
||||
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||
let exp = self.exp()?;
|
||||
let sum = exp.sum(sum_dims)?;
|
||||
sum.log()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
Reference in New Issue
Block a user