diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 4d9b0837..8950f2c5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2565,6 +2565,13 @@ impl Tensor { } mask.where_cond(/* on_true= */ &src, /* on_false= */ self) } + + /// Returns log(sum(exp(tensor), dim)). + pub fn logsumexp(&self, sum_dims: D) -> Result { + let exp = self.exp()?; + let sum = exp.sum(sum_dims)?; + sum.log() + } } macro_rules! bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index c871dc96..95eadc24 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1,4 +1,4 @@ -use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor}; +use candle_core::{test_device, test_utils, D, DType, Device, IndexOp, Result, Tensor}; fn zeros(device: &Device) -> Result<()> { let tensor = Tensor::zeros((5, 2), DType::F32, device)?; @@ -1221,3 +1221,28 @@ fn cumsum() -> Result<()> { ); Ok(()) } + +/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data. +/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon. +fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) { + let a_vec: Vec = a.to_vec1().unwrap(); + let b_vec: Vec = b.to_vec1().unwrap(); + + assert_eq!(a_vec.len(), b_vec.len()); + for (a, b) in a_vec.iter().zip(b_vec.iter()) { + assert!((a - b).abs() < epsilon); + } +} + +#[test] +fn logsumexp() -> Result<()> { + let input = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let output = input.logsumexp(D::Minus1)?; + + // Expectation get from pytorch. + let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; + + assert_close(&output, &expected, 0.00001); + + Ok(()) +} \ No newline at end of file