mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add flip
to tensor
(#2855)
* Add `flip` to `tensor` * Move the tests to the proper places. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -24,6 +24,15 @@ macro_rules! test_device {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
|
||||
assert_eq!(t1.shape(), t2.shape());
|
||||
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
|
||||
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
|
||||
let all_equal = eq_tensor.sum_all()?;
|
||||
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec0::<f32>()?;
|
||||
|
Reference in New Issue
Block a user