mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Squeeze/unsqueeze/reshape
This commit is contained in:
@ -995,6 +995,18 @@ impl Tensor {
|
|||||||
/// original tensor is the same.
|
/// original tensor is the same.
|
||||||
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
|
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
|
||||||
/// a new storage and copies the data over, the returned tensor is always contiguous.
|
/// a new storage and copies the data over, the returned tensor is always contiguous.
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle::{Tensor, DType, Device, D};
|
||||||
|
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let c = a.reshape((1, 6))?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[1, 6]);
|
||||||
|
///
|
||||||
|
/// let c = a.reshape((3, 2))?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[3, 2]);
|
||||||
|
/// # Ok::<(), candle::Error>(())
|
||||||
|
/// ```
|
||||||
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
|
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
if shape.elem_count() != self.elem_count() {
|
if shape.elem_count() != self.elem_count() {
|
||||||
@ -1026,6 +1038,19 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Removes an extra rank on the tensor with dimension 1.
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle::{Tensor, DType, Device, D};
|
||||||
|
/// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let c = a.squeeze(2)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2, 3]);
|
||||||
|
///
|
||||||
|
/// let c = a.squeeze(D::Minus1)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2, 3]);
|
||||||
|
/// # Ok::<(), candle::Error>(())
|
||||||
|
/// ```
|
||||||
pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
// The PyTorch semantics are to return the same tensor if the target dimension
|
// The PyTorch semantics are to return the same tensor if the target dimension
|
||||||
// does not have a size of 1.
|
// does not have a size of 1.
|
||||||
@ -1040,8 +1065,23 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
|
/// Creates an extra rank on the tensor with dimension 1.
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle::{Tensor, DType, Device, D};
|
||||||
|
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let c = a.unsqueeze(0)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[1, 2, 3]);
|
||||||
|
///
|
||||||
|
/// let c = a.unsqueeze(D::Minus1)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
|
||||||
|
/// # Ok::<(), candle::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
let mut dims = self.dims().to_vec();
|
let mut dims = self.dims().to_vec();
|
||||||
|
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
|
||||||
|
// Cannot panic because to_index_plus_one already checks dimensions
|
||||||
dims.insert(dim, 1);
|
dims.insert(dim, 1);
|
||||||
self.reshape(dims)
|
self.reshape(dims)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user