mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Take references as input for Tensor::cat.
This commit is contained in:
@ -802,7 +802,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
|
pub fn cat(args: &[&Self], dim: usize) -> Result<Self> {
|
||||||
if args.is_empty() {
|
if args.is_empty() {
|
||||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
||||||
}
|
}
|
||||||
@ -867,7 +867,8 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
let shape = Shape::from(cat_dims);
|
let shape = Shape::from(cat_dims);
|
||||||
let op = if args.iter().any(|arg| arg.track_op()) {
|
let op = if args.iter().any(|arg| arg.track_op()) {
|
||||||
Some(Op::Cat(args.to_vec(), dim))
|
let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect();
|
||||||
|
Some(Op::Cat(args, dim))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
@ -120,6 +120,7 @@ fn sum() -> Result<()> {
|
|||||||
tensor.sum(&[0])?.to_vec3::<u32>()?,
|
tensor.sum(&[0])?.to_vec3::<u32>()?,
|
||||||
&[[[5, 2, 11], [9, 7, 17]]],
|
&[[[5, 2, 11], [9, 7, 17]]],
|
||||||
);
|
);
|
||||||
|
assert_eq!(tensor.sum(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?,
|
tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?,
|
||||||
&[[[8], [15]], [[10], [18]]]
|
&[[[8], [15]], [[10], [18]]]
|
||||||
|
Reference in New Issue
Block a user