Fix the cpu implementation for narrow.

This commit is contained in:
laurent
2023-06-24 15:01:32 +01:00
parent 1b5f892d73
commit 0f34738831
2 changed files with 11 additions and 10 deletions

View File

@ -114,8 +114,7 @@ fn narrow() -> Result<()> {
let tensor = Tensor::new(data, &Device::Cpu)?;
assert_eq!(
tensor.narrow(2, 1, 2)?.to_vec3::<f32>()?,
// TODO: this is broken at the moment!
&[[[1., 4.], [1., 5.]], [[9., 2.], [1., 7.]]]
&[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]],
);
Ok(())
}