Fix the cat implementation + more testing.

This commit is contained in:
laurent
2023-06-25 15:32:13 +01:00
parent 118cc30908
commit 8b67f294e8
2 changed files with 35 additions and 12 deletions

View File

@ -210,18 +210,18 @@ fn cat() -> Result<()> {
.t()?
.to_vec2::<f32>()?,
[
[3.0, 4.0, 5.0, 5.0, 5.0],
[2.0, 1.0, 2.0, 7.0, 8.0],
[1.0, 1.0, 5.0, 5.0, 5.0],
[7.0, 8.0, 2.0, 1.0, 2.0]
[3.0, 1.0, 4.0, 1.0, 5.0],
[2.0, 7.0, 1.0, 8.0, 2.0],
[5.0, 5.0, 5.0, 5.0, 5.0],
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
// TODO: This is not the expected answer, to be fixed!
assert_eq!(
Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
[
[3.0, 1.0, 4.0, 1.0, 5.0, 2.0, 7.0, 1.0, 8.0, 2.0],
[5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 7.0, 1.0, 8.0, 2.0]
[3.0, 1.0, 4.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
]
);
Ok(())