diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 9b0681e0..ccac82d6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -995,6 +995,18 @@ impl Tensor {
     /// original tensor is the same.
     /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
     /// a new storage and copies the data over, the returned tensor is always contiguous.
+    ///
+    /// ```rust
+    /// # use candle::{Tensor, DType, Device, D};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = a.reshape((1, 6))?;
+    /// assert_eq!(c.shape().dims(), &[1, 6]);
+    ///
+    /// let c = a.reshape((3, 2))?;
+    /// assert_eq!(c.shape().dims(), &[3, 2]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
         let shape = shape.into();
         if shape.elem_count() != self.elem_count() {
@@ -1026,6 +1038,19 @@ impl Tensor {
         }
     }
 
+    /// Removes a dimension of size 1 at the specified index.
+    ///
+    /// ```rust
+    /// # use candle::{Tensor, DType, Device, D};
+    /// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = a.squeeze(2)?;
+    /// assert_eq!(c.shape().dims(), &[2, 3]);
+    ///
+    /// let c = a.squeeze(D::Minus1)?;
+    /// assert_eq!(c.shape().dims(), &[2, 3]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
         // The PyTorch semantics are to return the same tensor if the target dimension
         // does not have a size of 1.
@@ -1040,8 +1065,23 @@ impl Tensor {
         }
     }
 
-    pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
+    /// Inserts a dimension of size 1 at the specified index.
+    ///
+    /// ```rust
+    /// # use candle::{Tensor, DType, Device, D};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = a.unsqueeze(0)?;
+    /// assert_eq!(c.shape().dims(), &[1, 2, 3]);
+    ///
+    /// let c = a.unsqueeze(D::Minus1)?;
+    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
+    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
         let mut dims = self.dims().to_vec();
+        let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
+        // Cannot panic because to_index_plus_one already checks dimensions
         dims.insert(dim, 1);
         self.reshape(dims)
     }
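
The substantive change in the last hunk is that `unsqueeze` now accepts any `D: Dim` (a plain index or a relative one such as `D::Minus1`) instead of a bare `usize`, resolving it through `to_index_plus_one`, which also allows the insertion point one past the last axis. As a hedged illustration of the index arithmetic implied by the doc tests above, the standalone sketch below uses a hypothetical `resolve_unsqueeze_dim` helper; it is not candle's `to_index_plus_one` implementation.

```rust
// Standalone sketch (not the candle implementation): how an insertion index
// for unsqueeze can be resolved so that both a plain index and a
// relative-from-the-end index map into the valid range 0..=rank.
fn resolve_unsqueeze_dim(rank: usize, dim: isize) -> Option<usize> {
    if dim >= 0 {
        let d = dim as usize;
        // For unsqueeze, inserting at `rank` (one past the last axis) is valid.
        if d <= rank {
            Some(d)
        } else {
            None
        }
    } else {
        // -1 maps to `rank` (append a trailing axis), -2 to `rank - 1`, etc.
        let back = (-dim) as usize;
        if back <= rank + 1 {
            Some(rank + 1 - back)
        } else {
            None
        }
    }
}

fn main() {
    // A (2, 3) tensor has rank 2.
    assert_eq!(resolve_unsqueeze_dim(2, 0), Some(0)); // -> shape (1, 2, 3)
    assert_eq!(resolve_unsqueeze_dim(2, 2), Some(2)); // -> shape (2, 3, 1)
    assert_eq!(resolve_unsqueeze_dim(2, -1), Some(2)); // D::Minus1 -> trailing axis
    assert_eq!(resolve_unsqueeze_dim(2, 3), None); // out of range
}
```

With this mapping, a `Minus1`-style dimension on a rank-2 tensor resolves to index 2, which matches the `(2, 3) -> (2, 3, 1)` doc test for `a.unsqueeze(D::Minus1)` above.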