mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add flip
to tensor
(#2855)
* Add `flip` to `tensor` * Move the tests to the proper places. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -2580,6 +2580,28 @@ impl Tensor {
|
||||
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||
}
|
||||
|
||||
/// Returns a new tensor with the order of elements reversed along the specified dimensions.
|
||||
/// This function makes a copy of the tensor’s data.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, Device};
|
||||
/// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
|
||||
/// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
/// let t_flipped = t.flip(&[0])?;
|
||||
/// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
|
||||
let mut result = self.clone();
|
||||
for &dim in dims.iter() {
|
||||
let size = result.dim(dim)?;
|
||||
let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
|
||||
let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
|
||||
result = result.index_select(&indices_tensor, dim)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
@ -24,6 +24,15 @@ macro_rules! test_device {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
|
||||
assert_eq!(t1.shape(), t2.shape());
|
||||
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
|
||||
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
|
||||
let all_equal = eq_tensor.sum_all()?;
|
||||
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
let t = t.to_vec0::<f32>()?;
|
||||
|
Reference in New Issue
Block a user