Fix the cat implementation + more testing.

This commit is contained in:
laurent
2023-06-25 15:32:13 +01:00
parent 118cc30908
commit 8b67f294e8
2 changed files with 35 additions and 12 deletions

View File

@ -817,11 +817,34 @@ impl Tensor {
shape: args[0].shape().clone(), shape: args[0].shape().clone(),
}); });
} }
if dim == 0 {
Self::cat0(args)
} else {
// TODO: Avoid these transpositions and have an implementation that works
// for dim != 0...
let args: Vec<Tensor> = args
.iter()
.map(|a| a.transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let args: Vec<&Tensor> = args.iter().collect();
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
pub fn cat0(args: &[&Self]) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
}
if args.len() == 1 {
return Ok(args[0].clone());
}
let rank = args[0].rank();
let device = args[0].device(); let device = args[0].device();
let dtype = args[0].dtype(); let dtype = args[0].dtype();
let first_dims = args[0].shape().dims(); let first_dims = args[0].shape().dims();
let mut cat_dims = first_dims.to_vec(); let mut cat_dims = first_dims.to_vec();
cat_dims[dim] = 0; cat_dims[0] = 0;
let mut offsets = vec![0usize]; let mut offsets = vec![0usize];
for (arg_idx, arg) in args.iter().enumerate() { for (arg_idx, arg) in args.iter().enumerate() {
if arg.dtype() != dtype { if arg.dtype() != dtype {
@ -848,10 +871,10 @@ impl Tensor {
.zip(arg.shape().dims().iter()) .zip(arg.shape().dims().iter())
.enumerate() .enumerate()
{ {
if dim == dim_idx { if dim_idx == 0 {
cat_dims[dim] += v2; cat_dims[0] += v2;
} }
if dim != dim_idx && v1 != v2 { if dim_idx != 0 && v1 != v2 {
// TODO: It would probably be good to have a nicer error message here, i.e. // TODO: It would probably be good to have a nicer error message here, i.e.
// mention the problematic dimension and the values. // mention the problematic dimension and the values.
mismatch = true; mismatch = true;
@ -859,7 +882,7 @@ impl Tensor {
} }
if mismatch { if mismatch {
return Err(Error::ShapeMismatchCat { return Err(Error::ShapeMismatchCat {
dim, dim: 0, // TODO: not the appropriate error message
first_shape: args[0].shape().clone(), first_shape: args[0].shape().clone(),
n: arg_idx + 1, n: arg_idx + 1,
nth_shape: arg.shape().clone(), nth_shape: arg.shape().clone(),
@ -871,7 +894,7 @@ impl Tensor {
let shape = Shape::from(cat_dims); let shape = Shape::from(cat_dims);
let op = if args.iter().any(|arg| arg.track_op()) { let op = if args.iter().any(|arg| arg.track_op()) {
let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect(); let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect();
Some(Op::Cat(args, dim)) Some(Op::Cat(args, 0))
} else { } else {
None None
}; };

View File

@ -210,18 +210,18 @@ fn cat() -> Result<()> {
.t()? .t()?
.to_vec2::<f32>()?, .to_vec2::<f32>()?,
[ [
[3.0, 4.0, 5.0, 5.0, 5.0], [3.0, 1.0, 4.0, 1.0, 5.0],
[2.0, 1.0, 2.0, 7.0, 8.0], [2.0, 7.0, 1.0, 8.0, 2.0],
[1.0, 1.0, 5.0, 5.0, 5.0], [5.0, 5.0, 5.0, 5.0, 5.0],
[7.0, 8.0, 2.0, 1.0, 2.0] [2.0, 7.0, 1.0, 8.0, 2.0]
] ]
); );
// TODO: This is not the expected answer, to be fixed! // TODO: This is not the expected answer, to be fixed!
assert_eq!( assert_eq!(
Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?, Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
[ [
[3.0, 1.0, 4.0, 1.0, 5.0, 2.0, 7.0, 1.0, 8.0, 2.0], [3.0, 1.0, 4.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
[5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 7.0, 1.0, 8.0, 2.0] [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
] ]
); );
Ok(()) Ok(())