diff --git a/src/op.rs b/src/op.rs index 22b28299..45fe97a4 100644 --- a/src/op.rs +++ b/src/op.rs @@ -17,7 +17,6 @@ pub(crate) enum Op { add: f64, }, Neg(Tensor), - #[allow(dead_code)] Reshape(Tensor), Sqr(Tensor), Sqrt(Tensor), diff --git a/src/tensor.rs b/src/tensor.rs index 99cd6064..e9a41937 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -595,6 +595,36 @@ impl Tensor { } } + // TODO: Do we want to allow target shape using -1 on some dimensions? + pub fn reshape>(&self, shape: S) -> Result { + let shape = shape.into(); + if shape.elem_count() != self.elem_count() { + return Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: shape, + op: "reshape", + }); + } + let mut storage = self.device().zeros(&shape, self.dtype())?; + self.storage + .copy_strided_src(&mut storage, &shape, &self.stride, 0)?; + let op = if self.track_op() { + Some(Op::Reshape(self.clone())) + } else { + None + }; + let stride = shape.stride_contiguous(); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage, + shape, + stride, + op, + is_variable: false, + }; + Ok(Tensor(Arc::new(tensor_))) + } + pub fn cat(args: &[Self], dim: usize) -> Result { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });