mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Bugfix for Tensor::cat + add some tests.
This commit is contained in:
@ -98,8 +98,8 @@ fn copy_strided_src_<T: Copy>(
|
|||||||
) {
|
) {
|
||||||
let src = &src[src_offset..];
|
let src = &src[src_offset..];
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_shape.is_contiguous(src_stride) {
|
||||||
let elem_to_copy = dst.len() - dst_offset;
|
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
||||||
dst[dst_offset..].copy_from_slice(&src[..elem_to_copy])
|
dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
|
||||||
} else {
|
} else {
|
||||||
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
|
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
|
||||||
for (dst_index, src_index) in src_indexes.enumerate() {
|
for (dst_index, src_index) in src_indexes.enumerate() {
|
||||||
|
@ -806,6 +806,9 @@ impl Tensor {
|
|||||||
if args.is_empty() {
|
if args.is_empty() {
|
||||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
||||||
}
|
}
|
||||||
|
if args.len() == 1 {
|
||||||
|
return Ok(args[0].clone());
|
||||||
|
}
|
||||||
let rank = args[0].rank();
|
let rank = args[0].rank();
|
||||||
if dim >= rank {
|
if dim >= rank {
|
||||||
return Err(Error::UnexpectedNumberOfDims {
|
return Err(Error::UnexpectedNumberOfDims {
|
||||||
@ -875,7 +878,7 @@ impl Tensor {
|
|||||||
let mut storage = device.zeros(&shape, dtype)?;
|
let mut storage = device.zeros(&shape, dtype)?;
|
||||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||||
arg.storage
|
arg.storage
|
||||||
.copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?
|
.copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?;
|
||||||
}
|
}
|
||||||
Ok(from_storage(storage, shape, op, false))
|
Ok(from_storage(storage, shape, op, false))
|
||||||
}
|
}
|
||||||
|
@ -173,3 +173,20 @@ fn broadcast() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cat() -> Result<()> {
|
||||||
|
let t1 = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
|
||||||
|
let t2 = Tensor::new(&[1f32, 5., 9., 2.], &Device::Cpu)?;
|
||||||
|
let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], &Device::Cpu)?;
|
||||||
|
assert_eq!(Tensor::cat(&[&t1], 0)?.to_vec1::<f32>()?, [3f32, 1., 4.],);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::cat(&[&t1, &t2], 0)?.to_vec1::<f32>()?,
|
||||||
|
[3f32, 1., 4., 1., 5., 9., 2.],
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Tensor::cat(&[&t1, &t2, &t3], 0)?.to_vec1::<f32>()?,
|
||||||
|
[3f32, 1., 4., 1., 5., 9., 2., 6., 5., 3., 5., 8., 9.],
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user