Reshape can now return a view.

This commit is contained in:
laurent
2023-06-24 07:14:09 +01:00
parent 47f9c48e7c
commit 3deacba5f9

View File

@ -648,8 +648,9 @@ impl Tensor {
// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same. This uses a new storage and copies the data over, the returned
/// tensor is always contiguous.
/// original tensor is the same.
/// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
/// a new storage and copies the data over, the returned tensor is always contiguous.
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
let shape = shape.into();
if shape.elem_count() != self.elem_count() {
@ -659,15 +660,28 @@ impl Tensor {
op: "reshape",
});
}
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
let op = if self.track_op() {
Some(Op::Reshape(self.clone()))
} else {
None
};
Ok(from_storage(storage, shape, op, false))
if self.is_contiguous() {
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
shape,
stride,
op,
is_variable: false,
};
Ok(Tensor(Arc::new(tensor_)))
} else {
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
Ok(from_storage(storage, shape, op, false))
}
}
pub fn cat(args: &[Self], dim: usize) -> Result<Self> {