mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add squeeze/unsqueeze/stack.
This commit is contained in:
@ -924,6 +924,36 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn squeeze(&self, index: usize) -> Result<Self> {
|
||||||
|
// The PyTorch semantics are to return the same tensor if the target dimension
|
||||||
|
// does not have a size of 1.
|
||||||
|
let dims = self.dims();
|
||||||
|
if dims[index] == 1 {
|
||||||
|
let mut dims = dims.to_vec();
|
||||||
|
dims.remove(index);
|
||||||
|
self.reshape(dims)
|
||||||
|
} else {
|
||||||
|
Ok(self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unsqueeze(&self, index: usize) -> Result<Self> {
|
||||||
|
let mut dims = self.dims().to_vec();
|
||||||
|
dims.insert(index, 1);
|
||||||
|
self.reshape(dims)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stack<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||||
|
if args.is_empty() {
|
||||||
|
return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
|
||||||
|
}
|
||||||
|
let args = args
|
||||||
|
.iter()
|
||||||
|
.map(|t| t.as_ref().unsqueeze(dim))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
Self::cat(&args, dim)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||||
if args.is_empty() {
|
if args.is_empty() {
|
||||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
||||||
|
Reference in New Issue
Block a user