From 3deacba5f94cf5e87e9cf6eb45647d8608c1e906 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sat, 24 Jun 2023 07:14:09 +0100
Subject: [PATCH] Reshape can now return a view.

---
 src/tensor.rs | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/src/tensor.rs b/src/tensor.rs
index 508dee49..741bef0d 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -648,8 +648,9 @@ impl Tensor {
 
     // TODO: Do we want to allow target shape using -1 on some dimensions?
     /// Reshape returns a tensor with the target shape provided that the number of elements of the
-    /// original tensor is the same. This uses a new storage and copies the data over, the returned
-    /// tensor is always contiguous.
+    /// original tensor is the same.
+    /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
+    /// a new storage and copies the data over, the returned tensor is always contiguous.
     pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
         let shape = shape.into();
         if shape.elem_count() != self.elem_count() {
@@ -659,15 +660,28 @@ impl Tensor {
                 op: "reshape",
             });
         }
-        let mut storage = self.device().zeros(&shape, self.dtype())?;
-        self.storage
-            .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
         let op = if self.track_op() {
             Some(Op::Reshape(self.clone()))
         } else {
             None
         };
-        Ok(from_storage(storage, shape, op, false))
+        if self.is_contiguous() {
+            let stride = shape.stride_contiguous();
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage: self.storage.clone(),
+                shape,
+                stride,
+                op,
+                is_variable: false,
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        } else {
+            let mut storage = self.device().zeros(&shape, self.dtype())?;
+            self.storage
+                .copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
+            Ok(from_storage(storage, shape, op, false))
+        }
     }
 
     pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
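
A minimal usage sketch of the behavior this patch introduces, not part of the patch itself. It assumes the crate is consumed through the later candle_core public API (Tensor::zeros, Device::Cpu, t, reshape, is_contiguous), which may differ from the crate layout at this commit: a contiguous tensor reshapes into a view over the same storage, while a transposed (non-contiguous) tensor falls back to the copy path.

// Illustrative sketch only; assumes the later candle_core API.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Contiguous input: reshape only recomputes the strides, so the result
    // is a view sharing the original storage (no data copy).
    let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
    let b = a.reshape((3, 2))?;
    assert!(b.is_contiguous());

    // Non-contiguous input (e.g. after a transpose): reshape allocates new
    // storage and copies the data over, so the result is still contiguous.
    let c = a.t()?;
    assert!(!c.is_contiguous());
    let d = c.reshape((2, 3))?;
    assert!(d.is_contiguous());

    Ok(())
}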