diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 39f605d2..a12f2b4d 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -116,5 +116,25 @@ fn narrow() -> Result<()> { tensor.narrow(2, 1, 2)?.to_vec3::()?, &[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]], ); + assert_eq!( + tensor.narrow(1, 1, 1)?.to_vec3::()?, + &[[[1.0, 5.0, 9.0]], [[8.0, 2.0, 8.0]]], + ); + assert_eq!( + tensor.narrow(0, 0, 1)?.to_vec3::()?, + &[[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]], + ); + assert_eq!( + tensor.narrow(0, 1, 1)?.to_vec3::()?, + &[[[2.0, 1.0, 7.0], [8.0, 2.0, 8.0]]], + ); + // The following has been checked against PyTorch via: + // import torch + // t = torch.tensor([[[3., 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]) + // t.transpose(-1, -2).narrow(1, 1, 2) + assert_eq!( + tensor.t()?.narrow(1, 1, 2)?.to_vec3::()?, + &[[[1.0, 5.0], [4.0, 9.0]], [[1.0, 2.0], [7.0, 8.0]]], + ); Ok(()) }