Add the alloc_uninit function. (#1901)

* Add the alloc_uninit function.

* Dummy metal fix.

* Lazy initialization.
This commit is contained in:
Laurent Mazare
2024-03-22 07:25:23 +01:00
committed by GitHub
parent a00e24d752
commit 6708870e63
9 changed files with 154 additions and 16 deletions

View File

@@ -141,7 +141,7 @@ impl Tensor {
}
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
-let mut storage = device.zeros(&shape, dtype)?;
+let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
@@ -215,7 +215,7 @@ impl Tensor {
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
-let mut storage = device.zeros(&shape, dtype)?;
+let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let mut dst_o = 0;
for arg in args.iter() {
let arg = arg.as_ref();