mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Modular backends (#138)
* Add some trait to formalize backends. * Use the generic backend trait.
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
use crate::backend::BackendDevice;
|
||||
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
|
||||
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
|
||||
@ -85,10 +86,10 @@ impl Device {
|
||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||
}
|
||||
|
||||
pub fn same_id(&self, rhs: &Self) -> bool {
|
||||
pub fn same_device(&self, rhs: &Self) -> bool {
|
||||
match (self, rhs) {
|
||||
(Self::Cpu, Self::Cpu) => true,
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@ -96,9 +97,7 @@ impl Device {
|
||||
pub fn location(&self) -> DeviceLocation {
|
||||
match self {
|
||||
Self::Cpu => DeviceLocation::Cpu,
|
||||
Self::Cuda(device) => DeviceLocation::Cuda {
|
||||
gpu_id: device.ordinal(),
|
||||
},
|
||||
Self::Cuda(device) => device.location(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -178,7 +177,7 @@ impl Device {
|
||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||
Device::Cuda(device) => {
|
||||
let storage = array.to_cpu_storage();
|
||||
let storage = device.cuda_from_cpu_storage(&storage)?;
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
@ -189,7 +188,7 @@ impl Device {
|
||||
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
|
||||
Device::Cuda(device) => {
|
||||
let storage = S::to_cpu_storage_owned(data);
|
||||
let storage = device.cuda_from_cpu_storage(&storage)?;
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user