mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add flip
to tensor
(#2855)
* Add `flip` to `tensor` * Move the tests to the proper places. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
#![allow(clippy::approx_constant)]
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
||||
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
|
||||
|
||||
fn simple_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||
@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flip_backprop() -> Result<()> {
|
||||
let device = &Device::Cpu;
|
||||
|
||||
// Create a tensor (leaf node) that requires gradients
|
||||
let x = Var::ones((2, 2), DType::F64, device)?;
|
||||
let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
|
||||
|
||||
let y = x.matmul(&weights)?;
|
||||
let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
|
||||
|
||||
let z = y.flip(&[1])?;
|
||||
let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
|
||||
|
||||
let loss = z.sum_all()?;
|
||||
|
||||
let grad_store = loss.backward()?;
|
||||
let grad_x = grad_store.get_id(x.id()).unwrap();
|
||||
|
||||
let flipped_weights = weights.flip(&[1])?;
|
||||
let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
|
||||
// dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
|
||||
let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
|
||||
candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(
|
||||
simple_grad,
|
||||
simple_grad_cpu,
|
||||
|
Reference in New Issue
Block a user