mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Squeeze/unsqueeze/reshape
This commit is contained in:
@ -995,6 +995,18 @@ impl Tensor {
|
||||
/// original tensor is the same.
|
||||
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
|
||||
/// a new storage and copies the data over, the returned tensor is always contiguous.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle::{Tensor, DType, Device, D};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = a.reshape((1, 6))?;
|
||||
/// assert_eq!(c.shape().dims(), &[1, 6]);
|
||||
///
|
||||
/// let c = a.reshape((3, 2))?;
|
||||
/// assert_eq!(c.shape().dims(), &[3, 2]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
|
||||
let shape = shape.into();
|
||||
if shape.elem_count() != self.elem_count() {
|
||||
@ -1026,6 +1038,19 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes an extra rank on the tensor with dimension 1.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle::{Tensor, DType, Device, D};
|
||||
/// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = a.squeeze(2)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 3]);
|
||||
///
|
||||
/// let c = a.squeeze(D::Minus1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 3]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
// The PyTorch semantics are to return the same tensor if the target dimension
|
||||
// does not have a size of 1.
|
||||
@ -1040,8 +1065,23 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
|
||||
/// Creates an extra rank on the tensor with dimension 1.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle::{Tensor, DType, Device, D};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = a.unsqueeze(0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[1, 2, 3]);
|
||||
///
|
||||
/// let c = a.unsqueeze(D::Minus1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
let mut dims = self.dims().to_vec();
|
||||
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
|
||||
// Cannot panic because to_index_plus_one already checks dimensions
|
||||
dims.insert(dim, 1);
|
||||
self.reshape(dims)
|
||||
}
|
||||
|
Reference in New Issue
Block a user