From 9541467d6bef38263afaa33c78374cd37e3d659f Mon Sep 17 00:00:00 2001
From: Bryan Lee
Date: Tue, 1 Apr 2025 03:07:16 -0400
Subject: [PATCH] Add `flip` to `tensor` (#2855)

* Add `flip` to `tensor`

* Move the tests to the proper places.

---------

Co-authored-by: laurent
---
 candle-core/src/tensor.rs         | 22 +++++++++++++
 candle-core/src/test_utils.rs     |  9 ++++++
 candle-core/tests/grad_tests.rs   | 32 ++++++++++++++++++-
 candle-core/tests/tensor_tests.rs | 51 +++++++++++++++++++++++++++++++
 4 files changed, 113 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 31699288..6a06836d 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2580,6 +2580,28 @@ impl Tensor {
     pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
         rhs.broadcast_mul(&self.log()?)?.exp()
     }
+
+    /// Returns a new tensor with the order of elements reversed along the specified dimensions.
+    /// This function makes a copy of the tensor’s data.
+    ///
+    /// ```rust
+    /// # use candle_core::{Tensor, Device};
+    /// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
+    /// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+    /// let t_flipped = t.flip(&[0])?;
+    /// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
+        let mut result = self.clone();
+        for &dim in dims.iter() {
+            let size = result.dim(dim)?;
+            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
+            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
+            result = result.index_select(&indices_tensor, dim)?;
+        }
+        Ok(result)
+    }
 }
 
 macro_rules! bin_trait {
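Note: as the hunk above shows, `flip` is built from repeated `index_select` calls with a reversed index tensor, one per listed dim, so it always materializes a copy and reuses `index_select`'s existing backward pass. A minimal standalone sketch of that equivalence (the function name is illustrative, not part of the patch):

    use candle_core::{Device, Tensor};

    fn check_flip_matches_index_select() -> candle_core::Result<()> {
        let t = Tensor::arange(0f32, 6., &Device::Cpu)?.reshape((2, 3))?;
        // Reversed row indices, matching what `flip(&[0])` builds internally.
        let rev = Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?;
        let via_index_select = t.index_select(&rev, 0)?;
        let via_flip = t.flip(&[0])?;
        assert_eq!(via_index_select.to_vec2::<f32>()?, via_flip.to_vec2::<f32>()?);
        Ok(())
    }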
diff --git a/candle-core/src/test_utils.rs b/candle-core/src/test_utils.rs
index 3b8fb904..e331399f 100644
--- a/candle-core/src/test_utils.rs
+++ b/candle-core/src/test_utils.rs
@@ -24,6 +24,15 @@ macro_rules! test_device {
     };
 }
 
+pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
+    assert_eq!(t1.shape(), t2.shape());
+    // Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
+    let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
+    let all_equal = eq_tensor.sum_all()?;
+    assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
+    Ok(())
+}
+
 pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
     let b = 10f32.powi(digits);
     let t = t.to_vec0::<f32>()?;
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index b8b6be8d..b5e4e280 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -1,6 +1,6 @@
 #![allow(clippy::approx_constant)]
 use anyhow::{Context, Result};
-use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
+use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
 
 fn simple_grad(device: &Device) -> Result<()> {
     let x = Var::new(&[3f32, 1., 4.], device)?;
@@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn test_flip_backprop() -> Result<()> {
+    let device = &Device::Cpu;
+
+    // Create a tensor (leaf node) that requires gradients.
+    let x = Var::ones((2, 2), DType::F64, device)?;
+    let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
+
+    let y = x.matmul(&weights)?;
+    let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
+    candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
+
+    let z = y.flip(&[1])?;
+    let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
+    candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
+
+    let loss = z.sum_all()?;
+
+    let grad_store = loss.backward()?;
+    let grad_x = grad_store.get_id(x.id()).unwrap();
+
+    let flipped_weights = weights.flip(&[1])?;
+    let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
+    // dloss/dx = dloss/dy @ dy/dx = ones @ weights.flip(&[1]).t()
+    let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
+    candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
+
+    Ok(())
+}
+
 test_device!(
     simple_grad,
     simple_grad_cpu,
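Note: since `flip` is a pure permutation of elements, its gradient is the same permutation applied to the upstream gradient, which the test above exercises indirectly through `index_select`'s backward pass. A standalone sanity check along the same lines (illustrative only, not part of the patch):

    use candle_core::{Device, Tensor, Var};

    fn check_flip_grad_is_flipped() -> candle_core::Result<()> {
        let device = Device::Cpu;
        let x = Var::new(&[[1f64, 2.], [3., 4.]], &device)?;
        // Distinct weights make the permutation visible in the gradient.
        let w = Tensor::new(&[[10f64, 20.], [30., 40.]], &device)?;
        let loss = x.flip(&[1])?.mul(&w)?.sum_all()?;
        let grads = loss.backward()?;
        let grad_x = grads.get_id(x.id()).unwrap();
        // d(loss)/dx is w flipped back along the same dim.
        assert_eq!(grad_x.to_vec2::<f64>()?, vec![vec![20., 10.], vec![40., 30.]]);
        Ok(())
    }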
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 17238dcd..36942ff2 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1682,3 +1682,54 @@ fn pow() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn test_flip_1d() -> Result<()> {
+    // 1D: [0, 1, 2, 3, 4]
+    let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
+    let flipped = t.flip(&[0])?;
+    // Expected: [4, 3, 2, 1, 0]
+    let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
+    candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
+    Ok(())
+}
+
+#[test]
+fn test_flip_2d() -> Result<()> {
+    // 2D:
+    // [[0, 1, 2],
+    //  [3, 4, 5]]
+    let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
+    let flipped = t.flip(&[0, 1])?;
+    // Expected:
+    // [[5, 4, 3],
+    //  [2, 1, 0]]
+    let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
+    candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
+    Ok(())
+}
+
+#[test]
+fn test_flip_3d_channels() -> Result<()> {
+    // 3D:
+    // [[[0, 1, 2],
+    //   [3, 4, 5]],
+    //
+    //  [[6, 7, 8],
+    //   [9, 10, 11]]]
+    let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
+    let flipped = t.flip(&[2])?;
+    // Expected:
+    // [[[2, 1, 0],
+    //   [5, 4, 3]],
+    //
+    //  [[8, 7, 6],
+    //   [11, 10, 9]]]
+    let expected = Tensor::from_vec(
+        vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
+        (2, 2, 3),
+        &Device::Cpu,
+    )?;
+    candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
+    Ok(())
+}
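Note: flipping several dims in one call composes the per-dim flips, so the `flip(&[0, 1])` in `test_flip_2d` above is equivalent to two chained single-dim flips. A quick illustrative check under the same setup (the function name is not part of the patch):

    use candle_core::{Device, Tensor};

    fn check_flip_composes() -> candle_core::Result<()> {
        let t = Tensor::arange(0f32, 6., &Device::Cpu)?.reshape((2, 3))?;
        let both = t.flip(&[0, 1])?;
        let chained = t.flip(&[0])?.flip(&[1])?;
        assert_eq!(both.to_vec2::<f32>()?, chained.to_vec2::<f32>()?);
        Ok(())
    }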