mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Support dim indexes in cat.
This commit is contained in:
@ -970,10 +970,11 @@ impl Tensor {
|
||||
self.reshape(dims)
|
||||
}
|
||||
|
||||
pub fn stack<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||
pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
|
||||
}
|
||||
let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
|
||||
let args = args
|
||||
.iter()
|
||||
.map(|t| t.as_ref().unsqueeze(dim))
|
||||
@ -981,7 +982,7 @@ impl Tensor {
|
||||
Self::cat(&args, dim)
|
||||
}
|
||||
|
||||
pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
||||
}
|
||||
@ -989,6 +990,7 @@ impl Tensor {
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
|
Reference in New Issue
Block a user