Add squeeze/unsqueeze/stack.

This commit is contained in:
laurent
2023-06-27 19:32:00 +01:00
parent 1d504cc6b3
commit 934655a60d

View File

@ -924,6 +924,36 @@ impl Tensor {
}
}
pub fn squeeze(&self, index: usize) -> Result<Self> {
// The PyTorch semantics are to return the same tensor if the target dimension
// does not have a size of 1.
let dims = self.dims();
if dims[index] == 1 {
let mut dims = dims.to_vec();
dims.remove(index);
self.reshape(dims)
} else {
Ok(self.clone())
}
}
pub fn unsqueeze(&self, index: usize) -> Result<Self> {
let mut dims = self.dims().to_vec();
dims.insert(index, 1);
self.reshape(dims)
}
pub fn stack<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
}
let args = args
.iter()
.map(|t| t.as_ref().unsqueeze(dim))
.collect::<Result<Vec<_>>>()?;
Self::cat(&args, dim)
}
pub fn cat<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });