Add the alloc_uninit function. (#1901)

* Add the alloc_uninit function.

* Dummy metal fix.

* Lazy initialization.
This commit is contained in:
Laurent Mazare
2024-03-22 07:25:23 +01:00
committed by GitHub
parent a00e24d752
commit 6708870e63
9 changed files with 154 additions and 16 deletions

View File

@ -289,6 +289,23 @@ impl Device {
}
}
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),