Modular backends (#138)

* Add some trait to formalize backends.

* Use the generic backend trait.
This commit is contained in:
Laurent Mazare
2023-07-11 11:17:02 +01:00
committed by GitHub
parent 674eb35e10
commit 64264d97c1
9 changed files with 457 additions and 373 deletions

View File

@ -1,3 +1,4 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
@ -85,10 +86,10 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn same_id(&self, rhs: &Self) -> bool {
pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@ -96,9 +97,7 @@ impl Device {
pub fn location(&self) -> DeviceLocation {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => DeviceLocation::Cuda {
gpu_id: device.ordinal(),
},
Self::Cuda(device) => device.location(),
}
}
@ -178,7 +177,7 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
let storage = array.to_cpu_storage();
let storage = device.cuda_from_cpu_storage(&storage)?;
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
}
@ -189,7 +188,7 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.cuda_from_cpu_storage(&storage)?;
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
}