From 1b5f892d732a51dc46cc035fad5c75418fd11320 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 24 Jun 2023 08:50:37 +0100 Subject: [PATCH] Add a currently wrong test for narrow. --- src/tensor.rs | 5 +++-- tests/tensor_tests.rs | 13 +++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index b58ed739..a4ad93ec 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -300,7 +300,7 @@ impl Tensor { /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// ranges from `start` to `start + length`. - // TODO: Once we've refactor the shape and strides, make this return a view of the same data + // TODO: Once we've refactored the shape and strides, make this return a view of the same data // rather than copying. pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result { let dims = self.shape().dims(); @@ -318,7 +318,8 @@ impl Tensor { dims[dim] = length; let shape = Shape::from(dims); let mut storage = self.device().zeros(&shape, self.dtype())?; - let src_offset = 0; // TODO + let src_offset = self.stride[dim] * start; + // TODO: This is incorrect, see the currently wrong test in tensor_tests.rs self.storage .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, src_offset)?; let op = if self.track_op() { diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 0ffcad62..ce609de9 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -1,3 +1,4 @@ +// TODO: Also test the cuda backend. use candle::{DType, Device, Result, Tensor}; #[test] @@ -106,3 +107,15 @@ fn softmax() -> Result<()> { ); Ok(()) } + +#[test] +fn narrow() -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, &Device::Cpu)?; + assert_eq!( + tensor.narrow(2, 1, 2)?.to_vec3::()?, + // TODO: this is broken at the moment! + &[[[1., 4.], [1., 5.]], [[9., 2.], [1., 7.]]] + ); + Ok(()) +}