Add StorageRef. (#2113)

* Add the storage-ref bits.

* Add the metal implementation.
This commit is contained in:
Laurent Mazare
2024-04-23 13:23:27 +02:00
committed by GitHub
parent b2e816752b
commit 8a05743a21
10 changed files with 108 additions and 5 deletions

View File

@ -306,6 +306,20 @@ impl Device {
}
}
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
Device::Cuda(device) => {
let storage = device.storage_from_slice(data)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.storage_from_slice(data)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),