Add the alloc_uninit function. (#1901)

* Add the alloc_uninit function.

* Dummy metal fix.

* Lazy initialization.
This commit is contained in:
Laurent Mazare
2024-03-22 07:25:23 +01:00
committed by GitHub
parent a00e24d752
commit 6708870e63
9 changed files with 154 additions and 16 deletions

View File

@ -1349,7 +1349,7 @@ impl Tensor {
}
.bt())?
}
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let offset = start * src.dims()[1..].iter().product::<usize>();
@ -1999,7 +1999,7 @@ impl Tensor {
Ok(self.clone())
} else {
let shape = self.shape();
let mut storage = self.device().zeros(shape, self.dtype())?;
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
@ -2011,7 +2011,7 @@ impl Tensor {
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
let shape = self.shape().clone();
let mut storage = self.device().zeros(&shape, self.dtype())?;
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, BackpropOp::none(), true))
@ -2064,7 +2064,7 @@ impl Tensor {
};
Ok(Tensor(Arc::new(tensor_)))
} else {
let mut storage = self.device().zeros(&shape, self.dtype())?;
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, op, false))