mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add squeeze/unsqueeze/stack.
This commit is contained in:
@ -924,6 +924,36 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn squeeze(&self, index: usize) -> Result<Self> {
|
||||
// The PyTorch semantics are to return the same tensor if the target dimension
|
||||
// does not have a size of 1.
|
||||
let dims = self.dims();
|
||||
if dims[index] == 1 {
|
||||
let mut dims = dims.to_vec();
|
||||
dims.remove(index);
|
||||
self.reshape(dims)
|
||||
} else {
|
||||
Ok(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unsqueeze(&self, index: usize) -> Result<Self> {
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims.insert(index, 1);
|
||||
self.reshape(dims)
|
||||
}
|
||||
|
||||
pub fn stack<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
|
||||
}
|
||||
let args = args
|
||||
.iter()
|
||||
.map(|t| t.as_ref().unsqueeze(dim))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Self::cat(&args, dim)
|
||||
}
|
||||
|
||||
pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
||||
|
Reference in New Issue
Block a user