From 934655a60d9bbf773d5b02c6aff2a8c26edd9be8 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 19:32:00 +0100 Subject: [PATCH] Add squeeze/unsqueeze/stack. --- candle-core/src/tensor.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index fc67ae94..b64f63e1 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -924,6 +924,36 @@ impl Tensor { } } + pub fn squeeze(&self, index: usize) -> Result { + // The PyTorch semantics are to return the same tensor if the target dimension + // does not have a size of 1. + let dims = self.dims(); + if dims[index] == 1 { + let mut dims = dims.to_vec(); + dims.remove(index); + self.reshape(dims) + } else { + Ok(self.clone()) + } + } + + pub fn unsqueeze(&self, index: usize) -> Result { + let mut dims = self.dims().to_vec(); + dims.insert(index, 1); + self.reshape(dims) + } + + pub fn stack>(args: &[A], dim: usize) -> Result { + if args.is_empty() { + return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }); + } + let args = args + .iter() + .map(|t| t.as_ref().unsqueeze(dim)) + .collect::>>()?; + Self::cat(&args, dim) + } + pub fn cat>(args: &[A], dim: usize) -> Result { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });