From 8b67f294e8da32e1c6cd379f255ceade132d3b96 Mon Sep 17 00:00:00 2001 From: laurent Date: Sun, 25 Jun 2023 15:32:13 +0100 Subject: [PATCH] Fix the cat implementation + more testing. --- src/tensor.rs | 35 +++++++++++++++++++++++++++++------ tests/tensor_tests.rs | 12 ++++++------ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index 6e9a547a..53a8b1f9 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -817,11 +817,34 @@ impl Tensor { shape: args[0].shape().clone(), }); } + if dim == 0 { + Self::cat0(args) + } else { + // TODO: Avoid these transpositions and have an implementation that works + // for dim != 0... + let args: Vec = args + .iter() + .map(|a| a.transpose(0, dim)) + .collect::>>()?; + let args: Vec<&Tensor> = args.iter().collect(); + let cat = Self::cat0(&args)?; + cat.transpose(0, dim) + } + } + + pub fn cat0(args: &[&Self]) -> Result { + if args.is_empty() { + return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); + } + if args.len() == 1 { + return Ok(args[0].clone()); + } + let rank = args[0].rank(); let device = args[0].device(); let dtype = args[0].dtype(); let first_dims = args[0].shape().dims(); let mut cat_dims = first_dims.to_vec(); - cat_dims[dim] = 0; + cat_dims[0] = 0; let mut offsets = vec![0usize]; for (arg_idx, arg) in args.iter().enumerate() { if arg.dtype() != dtype { @@ -848,10 +871,10 @@ impl Tensor { .zip(arg.shape().dims().iter()) .enumerate() { - if dim == dim_idx { - cat_dims[dim] += v2; + if dim_idx == 0 { + cat_dims[0] += v2; } - if dim != dim_idx && v1 != v2 { + if dim_idx != 0 && v1 != v2 { // TODO: It would probably be good to have a nicer error message here, i.e. // mention the problematic dimension and the values. mismatch = true; @@ -859,7 +882,7 @@ impl Tensor { } if mismatch { return Err(Error::ShapeMismatchCat { - dim, + dim: 0, // TODO: not the appropriate error message first_shape: args[0].shape().clone(), n: arg_idx + 1, nth_shape: arg.shape().clone(), @@ -871,7 +894,7 @@ impl Tensor { let shape = Shape::from(cat_dims); let op = if args.iter().any(|arg| arg.track_op()) { let args: Vec = args.iter().map(|&arg| arg.clone()).collect(); - Some(Op::Cat(args, dim)) + Some(Op::Cat(args, 0)) } else { None }; diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index cb52ba7c..8269ace1 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -210,18 +210,18 @@ fn cat() -> Result<()> { .t()? .to_vec2::()?, [ - [3.0, 4.0, 5.0, 5.0, 5.0], - [2.0, 1.0, 2.0, 7.0, 8.0], - [1.0, 1.0, 5.0, 5.0, 5.0], - [7.0, 8.0, 2.0, 1.0, 2.0] + [3.0, 1.0, 4.0, 1.0, 5.0], + [2.0, 7.0, 1.0, 8.0, 2.0], + [5.0, 5.0, 5.0, 5.0, 5.0], + [2.0, 7.0, 1.0, 8.0, 2.0] ] ); // TODO: This is not the expected answer, to be fixed! assert_eq!( Tensor::cat(&[&t1, &t2], 1)?.to_vec2::()?, [ - [3.0, 1.0, 4.0, 1.0, 5.0, 2.0, 7.0, 1.0, 8.0, 2.0], - [5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 7.0, 1.0, 8.0, 2.0] + [3.0, 1.0, 4.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0], + [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0] ] ); Ok(())