Add flip to tensor (#2855)

* Add `flip` to `tensor`

* Move the tests to the proper places.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Bryan Lee
2025-04-01 03:07:16 -04:00
committed by GitHub
parent 6429609090
commit 9541467d6b
4 changed files with 113 additions and 1 deletions

View File

@ -1,6 +1,6 @@
#![allow(clippy::approx_constant)]
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
fn simple_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
#[test]
fn test_flip_backprop() -> Result<()> {
let device = &Device::Cpu;
// Create a tensor (leaf node) that requires gradients
let x = Var::ones((2, 2), DType::F64, device)?;
let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
let y = x.matmul(&weights)?;
let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
let z = y.flip(&[1])?;
let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
let loss = z.sum_all()?;
let grad_store = loss.backward()?;
let grad_x = grad_store.get_id(x.id()).unwrap();
let flipped_weights = weights.flip(&[1])?;
let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
// dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
Ok(())
}
test_device!(
simple_grad,
simple_grad_cpu,

View File

@ -1682,3 +1682,54 @@ fn pow() -> Result<()> {
);
Ok(())
}
#[test]
fn test_flip_1d() -> Result<()> {
// 1D: [0, 1, 2, 3, 4]
let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
let flipped = t.flip(&[0])?;
// Expected: [4, 3, 2, 1, 0]
let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn test_flip_2d() -> Result<()> {
// 2D:
// [[0, 1, 2],
// [3, 4, 5]]
let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
let flipped = t.flip(&[0, 1])?;
// Expected:
// [[5, 4, 3],
// [2, 1, 0]]
let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn test_flip_3d_channels() -> Result<()> {
// 3D:
// [[[0,1,2],
// [3,4,5]],
//
// [[6,7,8],
// [9,10,11]]]
let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
let flipped = t.flip(&[2])?;
// Expected:
// [[[2,1,0],
// [5,4,3]],
//
// [[8,7,6],
// [11,10,9]]]
let expected = Tensor::from_vec(
vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
(2, 2, 3),
&Device::Cpu,
)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}