From a9c113248aa84df3716d564f4bbe1fd42dab94f6 Mon Sep 17 00:00:00 2001 From: laurent Date: Sun, 25 Jun 2023 13:03:05 +0100 Subject: [PATCH] Take references as input for Tensor::cat. --- src/tensor.rs | 5 +++-- tests/tensor_tests.rs | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index c206ae30..83598962 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -802,7 +802,7 @@ impl Tensor { } } - pub fn cat(args: &[Self], dim: usize) -> Result { + pub fn cat(args: &[&Self], dim: usize) -> Result { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); } @@ -867,7 +867,8 @@ impl Tensor { } let shape = Shape::from(cat_dims); let op = if args.iter().any(|arg| arg.track_op()) { - Some(Op::Cat(args.to_vec(), dim)) + let args: Vec = args.iter().map(|&arg| arg.clone()).collect(); + Some(Op::Cat(args, dim)) } else { None }; diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index e61d58ed..b4f8bdd0 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -120,6 +120,7 @@ fn sum() -> Result<()> { tensor.sum(&[0])?.to_vec3::()?, &[[[5, 2, 11], [9, 7, 17]]], ); + assert_eq!(tensor.sum(&[0, 2, 1])?.to_vec3::()?, &[[[51]]],); assert_eq!( tensor.t()?.sum(&[1])?.t()?.to_vec3::()?, &[[[8], [15]], [[10], [18]]]