Support dim indexes in cat.

This commit is contained in:
laurent
2023-07-05 20:39:08 +01:00
parent fc2ffcc72b
commit e2bfbcb79c
2 changed files with 14 additions and 13 deletions

View File

@ -970,10 +970,11 @@ impl Tensor {
self.reshape(dims)
}
pub fn stack<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
}
let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
let args = args
.iter()
.map(|t| t.as_ref().unsqueeze(dim))
@ -981,7 +982,7 @@ impl Tensor {
Self::cat(&args, dim)
}
pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
}
@ -989,6 +990,7 @@ impl Tensor {
if args.len() == 1 {
return Ok(arg0.clone());
}
let dim = dim.to_index(arg0.shape(), "cat")?;
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}