Squeeze/unsqueeze/reshape

Author: Nicolas Patry
Date: 2023-07-10 16:40:25 +02:00
Parent: 221b1aff65
Commit: e01d099b71


@@ -995,6 +995,18 @@ impl Tensor {
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over; the returned tensor is always contiguous.
///
/// ```rust
/// # use candle::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = a.reshape((1, 6))?;
/// assert_eq!(c.shape().dims(), &[1, 6]);
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
let shape = shape.into();
if shape.elem_count() != self.elem_count() {
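
As a usage note (not part of this commit), the contiguity remark and the element-count check documented above can be sketched as follows. This assumes `Tensor::t` and `Tensor::is_contiguous` are available in candle's public API, and that a mismatched element count makes `reshape` return an error rather than panic:

```rust
use candle::{DType, Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;

    // Reshaping a contiguous tensor only rearranges the metadata (a view).
    let v = a.reshape((3, 2))?;
    assert_eq!(v.shape().dims(), &[3, 2]);

    // A transposed tensor is no longer contiguous (assumption: Tensor::t and
    // Tensor::is_contiguous exist); reshape then copies the data into fresh
    // storage and the result is contiguous again.
    let t = a.t()?;
    assert!(!t.is_contiguous());
    let r = t.reshape((2, 3))?;
    assert!(r.is_contiguous());

    // Element counts must match; given the check above, this is expected to
    // return an error (6 elements cannot be reshaped into 7).
    assert!(a.reshape((7, 1)).is_err());
    Ok(())
}
```
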
@@ -1026,6 +1038,19 @@ impl Tensor {
}
}
/// Removes the specified dimension from the tensor if its size is 1, reducing the rank by one.
///
/// ```rust
/// # use candle::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
///
/// let c = a.squeeze(2)?;
/// assert_eq!(c.shape().dims(), &[2, 3]);
///
/// let c = a.squeeze(D::Minus1)?;
/// assert_eq!(c.shape().dims(), &[2, 3]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
// The PyTorch semantics are to return the same tensor if the target dimension
// does not have a size of 1.
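
A short sketch of the PyTorch-style semantics mentioned in the comment above (not part of the commit, relying only on the `squeeze` signature introduced here): squeezing a dimension whose size is not 1 leaves the shape untouched.

```rust
use candle::{DType, Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;

    // Squeezing a size-1 dimension removes it.
    assert_eq!(a.squeeze(2)?.shape().dims(), &[2, 3]);

    // Squeezing a dimension whose size is not 1 follows the PyTorch
    // semantics described above: the shape is returned unchanged.
    assert_eq!(a.squeeze(0)?.shape().dims(), &[2, 3, 1]);
    Ok(())
}
```
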
@@ -1040,8 +1065,23 @@ impl Tensor {
}
}
pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
/// Inserts a new dimension of size 1 at the specified index, increasing the tensor's rank by one.
///
/// ```rust
/// # use candle::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = a.unsqueeze(0)?;
/// assert_eq!(c.shape().dims(), &[1, 2, 3]);
///
/// let c = a.unsqueeze(D::Minus1)?;
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
let mut dims = self.dims().to_vec();
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
// Cannot panic because to_index_plus_one already checks dimensions
dims.insert(dim, 1);
self.reshape(dims)
}
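
As a usage note for the new generic `unsqueeze` (a sketch, not part of the commit), chaining it is a common way to prepare tensors for broadcasting, and `to_index_plus_one` is what allows a trailing dimension to be appended:

```rust
use candle::{D, DType, Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;

    // Chain inserts: (2, 3) -> (1, 2, 3) -> (1, 2, 3, 1).
    let b = a.unsqueeze(0)?.unsqueeze(D::Minus1)?;
    assert_eq!(b.shape().dims(), &[1, 2, 3, 1]);

    // Assuming to_index_plus_one accepts an index equal to the rank (as the
    // D::Minus1 doctest above suggests), unsqueeze(2) on a rank-2 tensor
    // appends a trailing size-1 dimension.
    let c = a.unsqueeze(2)?;
    assert_eq!(c.shape().dims(), &[2, 3, 1]);
    Ok(())
}
```
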