diff --git a/src/tensor.rs b/src/tensor.rs
index a4ad93ec..b40ed886 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -316,18 +316,21 @@ impl Tensor {
         }
         let mut dims = dims.to_vec();
         dims[dim] = length;
-        let shape = Shape::from(dims);
-        let mut storage = self.device().zeros(&shape, self.dtype())?;
-        let src_offset = self.stride[dim] * start;
-        // TODO: This is incorrect, see the currently wrong test in tensor_tests.rs
-        self.storage
-            .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, src_offset)?;
+        let adjusted_shape = Shape::from(dims);
+        let mut storage = self.device().zeros(&adjusted_shape, self.dtype())?;
+        self.storage.copy_strided_src(
+            &mut storage,
+            /* dst_offset= */ 0,
+            &adjusted_shape,
+            &self.stride,
+            /* src_offset= */ self.stride[dim] * start,
+        )?;
         let op = if self.track_op() {
             Some(Op::Narrow(self.clone(), dim, start, length))
         } else {
             None
         };
-        Ok(from_storage(storage, shape, op, false))
+        Ok(from_storage(storage, adjusted_shape, op, false))
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
@@ -881,7 +884,6 @@ impl Tensor {
 
     pub fn backward(&self) -> Result<GradStore> {
         let sorted_nodes = self.sorted_nodes();
-        println!("{}", sorted_nodes.len());
         let mut grads = GradStore::new();
         grads.insert(self, self.ones_like()?);
         for node in sorted_nodes.iter() {
diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs
index ce609de9..39f605d2 100644
--- a/tests/tensor_tests.rs
+++ b/tests/tensor_tests.rs
@@ -114,8 +114,7 @@ fn narrow() -> Result<()> {
     let tensor = Tensor::new(data, &Device::Cpu)?;
     assert_eq!(
         tensor.narrow(2, 1, 2)?.to_vec3::<f32>()?,
-        // TODO: this is broken at the moment!
-        &[[[1., 4.], [1., 5.]], [[9., 2.], [1., 7.]]]
+        &[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]],
     );
     Ok(())
 }