Add flip to tensor (#2855)

* Add `flip` to `tensor`

* Move the tests to the proper places.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Bryan Lee
2025-04-01 03:07:16 -04:00
committed by GitHub
parent 6429609090
commit 9541467d6b
4 changed files with 113 additions and 1 deletions

View File

@ -24,6 +24,15 @@ macro_rules! test_device {
};
}
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
assert_eq!(t1.shape(), t2.shape());
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
let all_equal = eq_tensor.sum_all()?;
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
Ok(())
}
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
let b = 10f32.powi(digits);
let t = t.to_vec0::<f32>()?;