diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 81c2e801..d1452164 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -53,3 +53,19 @@ fn binary_op() -> Result<()> { assert_eq!(content[0], [0., 0., 0., 0., 0.]); Ok(()) } + +#[test] +fn tensor_2d_transpose() -> Result<()> { + let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let tensor = Tensor::new(data, &Device::Cpu)?.t()?; + let dims = tensor.shape().r2()?; + assert_eq!(dims, (5, 2)); + assert_eq!( + tensor.to_vec2::()?, + &[[3f32, 2.], [1., 1.], [4., 7.], [1., 8.], [5., 2.]] + ); + assert_eq!(tensor.t()?.to_vec2::()?, data); + assert_eq!(tensor.contiguous()?.t()?.to_vec2::()?, data); + assert_eq!(((tensor + 1.)?.t()? - 1.)?.to_vec2::()?, data); + Ok(()) +}