Merge pull request #124 from LaurentMazare/new_doc

Squeeze/unsqueeze/reshape
This commit is contained in:
Nicolas Patry
2023-07-10 20:43:23 +02:00
committed by GitHub

View File

@ -1048,6 +1048,18 @@ impl Tensor {
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
///
/// ```rust
/// # use candle::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = a.reshape((1, 6))?;
/// assert_eq!(c.shape().dims(), &[1, 6]);
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
let shape = shape.into();
if shape.elem_count() != self.elem_count() {
@ -1079,6 +1091,19 @@ impl Tensor {
}
}
/// Removes an extra rank on the tensor with dimension 1.
///
/// ```rust
/// # use candle::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
///
/// let c = a.squeeze(2)?;
/// assert_eq!(c.shape().dims(), &[2, 3]);
///
/// let c = a.squeeze(D::Minus1)?;
/// assert_eq!(c.shape().dims(), &[2, 3]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
// The PyTorch semantics are to return the same tensor if the target dimension
// does not have a size of 1.
@ -1093,8 +1118,23 @@ impl Tensor {
}
}
pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
/// Creates an extra rank on the tensor with dimension 1.
///
/// ```rust
/// # use candle::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = a.unsqueeze(0)?;
/// assert_eq!(c.shape().dims(), &[1, 2, 3]);
///
/// let c = a.unsqueeze(D::Minus1)?;
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
let mut dims = self.dims().to_vec();
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
// Cannot panic because to_index_plus_one already checks dimensions
dims.insert(dim, 1);
self.reshape(dims)
}