Bugfix for Tensor::cat + add some tests.

This commit is contained in:
laurent
2023-06-25 14:20:42 +01:00
parent 90c140ff4b
commit bb6450ebbb
3 changed files with 23 additions and 3 deletions

View File

@ -173,3 +173,20 @@ fn broadcast() -> Result<()> {
);
Ok(())
}
#[test]
fn cat() -> Result<()> {
let t1 = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
let t2 = Tensor::new(&[1f32, 5., 9., 2.], &Device::Cpu)?;
let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], &Device::Cpu)?;
assert_eq!(Tensor::cat(&[&t1], 0)?.to_vec1::<f32>()?, [3f32, 1., 4.],);
assert_eq!(
Tensor::cat(&[&t1, &t2], 0)?.to_vec1::<f32>()?,
[3f32, 1., 4., 1., 5., 9., 2.],
);
assert_eq!(
Tensor::cat(&[&t1, &t2, &t3], 0)?.to_vec1::<f32>()?,
[3f32, 1., 4., 1., 5., 9., 2., 6., 5., 3., 5., 8., 9.],
);
Ok(())
}