mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add flip
to tensor
(#2855)
* Add `flip` to `tensor`
* Move the tests to the proper places.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -2580,6 +2580,28 @@ impl Tensor {
|
|||||||
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||||
rhs.broadcast_mul(&self.log()?)?.exp()
|
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a new tensor with the order of elements reversed along the specified dimensions.
|
||||||
|
/// This function makes a copy of the tensor’s data.
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle_core::{Tensor, Device};
|
||||||
|
/// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
|
||||||
|
/// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
|
/// let t_flipped = t.flip(&[0])?;
|
||||||
|
/// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
|
||||||
|
let mut result = self.clone();
|
||||||
|
for &dim in dims.iter() {
|
||||||
|
let size = result.dim(dim)?;
|
||||||
|
let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
|
||||||
|
let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
|
||||||
|
result = result.index_select(&indices_tensor, dim)?;
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
@ -24,6 +24,15 @@ macro_rules! test_device {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
|
||||||
|
assert_eq!(t1.shape(), t2.shape());
|
||||||
|
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
|
||||||
|
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
|
||||||
|
let all_equal = eq_tensor.sum_all()?;
|
||||||
|
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||||
let b = 10f32.powi(digits);
|
let b = 10f32.powi(digits);
|
||||||
let t = t.to_vec0::<f32>()?;
|
let t = t.to_vec0::<f32>()?;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#![allow(clippy::approx_constant)]
|
#![allow(clippy::approx_constant)]
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
|
||||||
|
|
||||||
fn simple_grad(device: &Device) -> Result<()> {
|
fn simple_grad(device: &Device) -> Result<()> {
|
||||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||||
@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks that gradients flow back through `flip`: the loss is `sum(flip(x @ w, [1]))`.
#[test]
fn test_flip_backprop() -> Result<()> {
    let cpu = Device::Cpu;

    // `x` is the only variable (leaf) of the graph; `w` is the constant [[1, 2], [3, 4]].
    let x = Var::ones((2, 2), DType::F64, &cpu)?;
    let w = Tensor::arange(1.0, 5.0, &cpu)?.reshape((2, 2))?;

    // Forward pass: ones(2, 2) @ w puts the column sums of `w` in every row.
    let y = x.matmul(&w)?;
    let want_y = Tensor::new(&[[4f64, 6.], [4., 6.]], &cpu)?;
    candle_core::test_utils::assert_tensor_eq(&y, &want_y)?;

    // Flipping the last dimension swaps the two columns.
    let z = y.flip(&[1])?;
    let want_z = Tensor::new(&[[6f64, 4.], [6., 4.]], &cpu)?;
    candle_core::test_utils::assert_tensor_eq(&z, &want_z)?;

    // Backward pass from the scalar loss.
    let grads = z.sum_all()?.backward()?;
    let grad_x = grads.get_id(x.id()).unwrap();

    // dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
    let w_flipped_t = w.flip(&[1])?.t()?;
    let want_grad = Tensor::ones((2, 2), DType::F64, &cpu)?.matmul(&w_flipped_t)?;
    candle_core::test_utils::assert_tensor_eq(grad_x, &want_grad)?;

    Ok(())
}
|
||||||
|
|
||||||
test_device!(
|
test_device!(
|
||||||
simple_grad,
|
simple_grad,
|
||||||
simple_grad_cpu,
|
simple_grad_cpu,
|
||||||
|
@ -1682,3 +1682,54 @@ fn pow() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Flipping the only dimension of a 1D tensor reverses it.
#[test]
fn test_flip_1d() -> Result<()> {
    let cpu = Device::Cpu;
    // Input:  [0, 1, 2, 3, 4]
    let input = Tensor::arange(0.0, 5.0, &cpu)?.reshape((5,))?;
    // Output: [4, 3, 2, 1, 0]
    let want = Tensor::new(&[4f64, 3., 2., 1., 0.], &cpu)?;
    let got = input.flip(&[0])?;
    candle_core::test_utils::assert_tensor_eq(&got, &want)?;
    Ok(())
}
|
||||||
|
|
||||||
|
/// Flipping both dimensions of a 2D tensor reverses rows and columns at once.
#[test]
fn test_flip_2d() -> Result<()> {
    let cpu = Device::Cpu;
    // Input:
    // [[0, 1, 2],
    //  [3, 4, 5]]
    let input = Tensor::arange(0.0, 6.0, &cpu)?.reshape((2, 3))?;
    // Output after flipping dimensions 0 and 1:
    let want = Tensor::new(&[[5f64, 4., 3.], [2., 1., 0.]], &cpu)?;
    let got = input.flip(&[0, 1])?;
    candle_core::test_utils::assert_tensor_eq(&got, &want)?;
    Ok(())
}
|
||||||
|
|
||||||
|
/// Flipping only the last dimension of a 3D tensor reverses each innermost row in place.
#[test]
fn test_flip_3d_channels() -> Result<()> {
    let cpu = Device::Cpu;
    // Input:
    // [[[0, 1, 2],
    //   [3, 4, 5]],
    //
    //  [[6, 7, 8],
    //   [9, 10, 11]]]
    let input = Tensor::arange(0.0, 12.0, &cpu)?.reshape((2, 2, 3))?;
    // Output after flipping dimension 2:
    let want = Tensor::new(
        &[
            [[2f64, 1., 0.], [5., 4., 3.]],
            [[8., 7., 6.], [11., 10., 9.]],
        ],
        &cpu,
    )?;
    let got = input.flip(&[2])?;
    candle_core::test_utils::assert_tensor_eq(&got, &want)?;
    Ok(())
}
|
||||||
|
Reference in New Issue
Block a user