Add the copy method.

This commit is contained in:
laurent
2023-06-23 08:12:52 +01:00
parent 552276749a
commit fc41ccb5bb
2 changed files with 15 additions and 0 deletions

View File

@ -1,5 +1,6 @@
use crate::Tensor;
#[derive(Clone)]
pub(crate) enum Op {
Add(Tensor, Tensor),
Mul(Tensor, Tensor),

View File

@ -490,6 +490,20 @@ impl Tensor {
self.shape.is_contiguous(&self.stride)
}
/// Compared to clone, this copies the actual storage but may fail because of running out of
/// memory.
pub fn copy(&self) -> Result<Tensor> {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.try_clone()?,
shape: self.shape.clone(),
stride: self.stride.clone(),
op: self.op.clone(),
is_variable: self.is_variable,
};
Ok(Tensor(Arc::new(tensor_)))
}
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
/// argument.