From bb6450ebbb5f9b14df7b88964cc73a4f1f0a592c Mon Sep 17 00:00:00 2001 From: laurent Date: Sun, 25 Jun 2023 14:20:42 +0100 Subject: [PATCH] Bugfix for Tensor::cat + add some tests. --- src/cpu_backend.rs | 4 ++-- src/tensor.rs | 5 ++++- tests/tensor_tests.rs | 17 +++++++++++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index b92fd931..73ace860 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -98,8 +98,8 @@ fn copy_strided_src_( ) { let src = &src[src_offset..]; if src_shape.is_contiguous(src_stride) { - let elem_to_copy = dst.len() - dst_offset; - dst[dst_offset..].copy_from_slice(&src[..elem_to_copy]) + let elem_to_copy = (dst.len() - dst_offset).min(src.len()); + dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy]) } else { let src_indexes = StridedIndex::new(src_shape.dims(), src_stride); for (dst_index, src_index) in src_indexes.enumerate() { diff --git a/src/tensor.rs b/src/tensor.rs index 83598962..6e9a547a 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -806,6 +806,9 @@ impl Tensor { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); } + if args.len() == 1 { + return Ok(args[0].clone()); + } let rank = args[0].rank(); if dim >= rank { return Err(Error::UnexpectedNumberOfDims { @@ -875,7 +878,7 @@ impl Tensor { let mut storage = device.zeros(&shape, dtype)?; for (arg, &offset) in args.iter().zip(offsets.iter()) { arg.storage - .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)? + .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?; } Ok(from_storage(storage, shape, op, false)) } diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index b4f8bdd0..aa3caf47 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -173,3 +173,20 @@ fn broadcast() -> Result<()> { ); Ok(()) } + +#[test] +fn cat() -> Result<()> { + let t1 = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?; + let t2 = Tensor::new(&[1f32, 5., 9., 2.], &Device::Cpu)?; + let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], &Device::Cpu)?; + assert_eq!(Tensor::cat(&[&t1], 0)?.to_vec1::()?, [3f32, 1., 4.],); + assert_eq!( + Tensor::cat(&[&t1, &t2], 0)?.to_vec1::()?, + [3f32, 1., 4., 1., 5., 9., 2.], + ); + assert_eq!( + Tensor::cat(&[&t1, &t2, &t3], 0)?.to_vec1::()?, + [3f32, 1., 4., 1., 5., 9., 2., 6., 5., 3., 5., 8., 9.], + ); + Ok(()) +}