mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix the logsumexp test. (#1426)
This commit is contained in:
@ -1,4 +1,4 @@
|
|||||||
use candle_core::{test_device, test_utils, D, DType, Device, IndexOp, Result, Tensor};
|
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
|
||||||
|
|
||||||
fn zeros(device: &Device) -> Result<()> {
|
fn zeros(device: &Device) -> Result<()> {
|
||||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||||
@ -1224,25 +1224,23 @@ fn cumsum() -> Result<()> {
|
|||||||
|
|
||||||
/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
|
/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
|
||||||
/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
|
/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
|
||||||
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) {
|
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
||||||
let a_vec: Vec<f64> = a.to_vec1().unwrap();
|
let a_vec: Vec<f64> = a.to_vec1()?;
|
||||||
let b_vec: Vec<f64> = b.to_vec1().unwrap();
|
let b_vec: Vec<f64> = b.to_vec1()?;
|
||||||
|
|
||||||
assert_eq!(a_vec.len(), b_vec.len());
|
assert_eq!(a_vec.len(), b_vec.len());
|
||||||
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
|
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
|
||||||
assert!((a - b).abs() < epsilon);
|
assert!((a - b).abs() < epsilon);
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn logsumexp() -> Result<()> {
|
fn logsumexp() -> Result<()> {
|
||||||
let input = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||||
let output = input.logsumexp(D::Minus1)?;
|
let output = input.logsumexp(D::Minus1)?;
|
||||||
|
// The expectations obtained from pytorch.
|
||||||
// Expectation get from pytorch.
|
|
||||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
||||||
|
assert_close(&output, &expected, 0.00001)?;
|
||||||
assert_close(&output, &expected, 0.00001);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
Reference in New Issue
Block a user