Add logsumexp function (#1424)

Wenqing Zong
2023-12-12 16:32:17 +00:00
committed by GitHub
parent 18eb87f25f
commit 77252ffb82
2 changed files with 33 additions and 1 deletion


@@ -2565,6 +2565,13 @@ impl Tensor {
         }
         mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
     }
+
+    /// Returns log(sum(exp(tensor), dim)).
+    pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        let exp = self.exp()?;
+        let sum = exp.sum(sum_dims)?;
+        sum.log()
+    }
 }
 
 macro_rules! bin_trait {
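Note that this implementation exponentiates before summing, so `exp` can overflow for large inputs. A numerically stable variant subtracts the per-dimension max first. Below is a minimal sketch of that approach as a free function, assuming candle's `max_keepdim`, `broadcast_sub`, `sum_keepdim`, and `broadcast_add` APIs; the helper name is hypothetical and this is not part of the commit:

```rust
use candle_core::{Result, Tensor};

/// Hypothetical helper (not part of this commit): logsumexp computed as
/// max(x) + log(sum(exp(x - max(x)))), which keeps exp() from overflowing.
fn stable_logsumexp(t: &Tensor, dim: usize) -> Result<Tensor> {
    let max = t.max_keepdim(dim)?;
    // Shift by the per-dim max; the largest exponent is now exactly 0.
    let shifted = t.broadcast_sub(&max)?;
    let lse = shifted.exp()?.sum_keepdim(dim)?.log()?;
    // Add the max back and drop the kept dimension.
    lse.broadcast_add(&max)?.squeeze(dim)
}
```

For inputs like those in the test below the two versions agree; they diverge only once `exp` saturates.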


@@ -1,4 +1,4 @@
-use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
+use candle_core::{test_device, test_utils, D, DType, Device, IndexOp, Result, Tensor};
 
 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -1221,3 +1221,28 @@ fn cumsum() -> Result<()> {
     );
     Ok(())
 }
+
+/// A helper function for floating point comparison. Both `a` and `b` must be
+/// 1D tensors containing the same number of elements.
+/// The assertion passes if the absolute difference of every element pair is
+/// smaller than `epsilon`.
+fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) {
+    let a_vec: Vec<f64> = a.to_vec1().unwrap();
+    let b_vec: Vec<f64> = b.to_vec1().unwrap();
+    assert_eq!(a_vec.len(), b_vec.len());
+    for (a, b) in a_vec.iter().zip(b_vec.iter()) {
+        assert!((a - b).abs() < epsilon);
+    }
+}
+
+#[test]
+fn logsumexp() -> Result<()> {
+    // Use f64 so that `assert_close` (which calls `to_vec1::<f64>`) sees a
+    // matching dtype.
+    let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+    let output = input.logsumexp(D::Minus1)?;
+    // Expected values obtained from PyTorch.
+    let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
+    assert_close(&output, &expected, 0.00001);
+    Ok(())
+}
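The expected values follow directly from the definition: for the first row, log(e^1 + e^2 + e^3) ≈ 3.4076. A standalone sanity check in plain Rust, independent of candle (illustrative only, not part of the test file):

```rust
fn main() {
    // logsumexp of each row of [[1, 2, 3], [4, 5, 6]] along the last dim.
    let row1 = (1f64.exp() + 2f64.exp() + 3f64.exp()).ln();
    let row2 = (4f64.exp() + 5f64.exp() + 6f64.exp()).ln();
    println!("{row1:.4} {row2:.4}"); // prints: 3.4076 6.4076
}
```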