Avoid keeping track of the copy ops when not necessary. (#239)

This commit is contained in:
Laurent Mazare
2023-07-25 10:06:01 +01:00
committed by GitHub
parent 944d70bd9a
commit be9c26180c

View File

@ -1434,11 +1434,16 @@ impl Tensor {
/// Compared to clone, this copies the actual storage but may fail because of running out of /// Compared to clone, this copies the actual storage but may fail because of running out of
/// memory. /// memory.
pub fn copy(&self) -> Result<Tensor> { pub fn copy(&self) -> Result<Tensor> {
let op = if self.track_op() {
Some(Op::Copy(self.clone()))
} else {
None
};
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),
storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)), storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
layout: self.layout.clone(), layout: self.layout.clone(),
op: Some(Op::Copy(self.clone())), op,
is_variable: false, is_variable: false,
dtype: self.dtype, dtype: self.dtype,
device: self.device.clone(), device: self.device.clone(),
@ -1571,12 +1576,12 @@ impl Tensor {
let mut storage = self.device().zeros(shape, self.dtype())?; let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage() self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?; .copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage( let op = if self.track_op() {
storage, Some(Op::Copy(self.clone()))
shape.clone(), } else {
Some(Op::Copy(self.clone())), None
false, };
)) Ok(from_storage(storage, shape.clone(), op, false))
} }
} }