Take references as input for Tensor::cat.

This commit is contained in:
laurent
2023-06-25 13:03:05 +01:00
parent 5e03a1bc29
commit a9c113248a
2 changed files with 4 additions and 2 deletions

View File

@ -802,7 +802,7 @@ impl Tensor {
} }
} }
pub fn cat(args: &[Self], dim: usize) -> Result<Self> { pub fn cat(args: &[&Self], dim: usize) -> Result<Self> {
if args.is_empty() { if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
} }
@ -867,7 +867,8 @@ impl Tensor {
} }
let shape = Shape::from(cat_dims); let shape = Shape::from(cat_dims);
let op = if args.iter().any(|arg| arg.track_op()) { let op = if args.iter().any(|arg| arg.track_op()) {
Some(Op::Cat(args.to_vec(), dim)) let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect();
Some(Op::Cat(args, dim))
} else { } else {
None None
}; };

View File

@ -120,6 +120,7 @@ fn sum() -> Result<()> {
tensor.sum(&[0])?.to_vec3::<u32>()?, tensor.sum(&[0])?.to_vec3::<u32>()?,
&[[[5, 2, 11], [9, 7, 17]]], &[[[5, 2, 11], [9, 7, 17]]],
); );
assert_eq!(tensor.sum(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
assert_eq!( assert_eq!(
tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?, tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?,
&[[[8], [15]], [[10], [18]]] &[[[8], [15]], [[10], [18]]]