Bugfix for Tensor::cat + add some tests.

This commit is contained in:
laurent
2023-06-25 14:20:42 +01:00
parent 90c140ff4b
commit bb6450ebbb
3 changed files with 23 additions and 3 deletions

View File

@ -98,8 +98,8 @@ fn copy_strided_src_<T: Copy>(
) {
let src = &src[src_offset..];
if src_shape.is_contiguous(src_stride) {
let elem_to_copy = dst.len() - dst_offset;
dst[dst_offset..].copy_from_slice(&src[..elem_to_copy])
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
} else {
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
for (dst_index, src_index) in src_indexes.enumerate() {

View File

@ -806,6 +806,9 @@ impl Tensor {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
}
if args.len() == 1 {
return Ok(args[0].clone());
}
let rank = args[0].rank();
if dim >= rank {
return Err(Error::UnexpectedNumberOfDims {
@ -875,7 +878,7 @@ impl Tensor {
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
arg.storage
.copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?
.copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?;
}
Ok(from_storage(storage, shape, op, false))
}