Shape with holes (#770)

* Shape with holes.

* rustfmt.
This commit is contained in:
Laurent Mazare
2023-09-08 08:38:13 +01:00
committed by GitHub
parent cfcbec9fc7
commit 0e250aee4f
3 changed files with 184 additions and 6 deletions

View File

@ -1685,12 +1685,15 @@ impl Tensor {
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}
// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
///
/// The shape can be specified using a tuple of `usize` and at most one `()` in which case
/// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
/// as to match the number of elements in the tensor.
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device, D};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@ -1700,10 +1703,14 @@ impl Tensor {
///
/// let c = a.reshape((3, 2))?;
/// assert_eq!(c.shape().dims(), &[3, 2]);
///
/// let c = a.reshape((2, (), 1))?;
/// assert_eq!(c.shape().dims(), &[2, 3, 1]);
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
let shape = shape.into();
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),