Add the reshape method and operation (without grad for now).

This commit is contained in:
laurent
2023-06-23 10:51:05 +01:00
parent c4c6167949
commit 79e4b29c2f
2 changed files with 30 additions and 1 deletions

View File

@ -595,6 +595,36 @@ impl Tensor {
}
}
// TODO: Do we want to allow target shape using -1 on some dimensions?
pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
let shape = shape.into();
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: shape,
op: "reshape",
});
}
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage
.copy_strided_src(&mut storage, &shape, &self.stride, 0)?;
let op = if self.track_op() {
Some(Op::Reshape(self.clone()))
} else {
None
};
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
shape,
stride,
op,
is_variable: false,
};
Ok(Tensor(Arc::new(tensor_)))
}
pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });