mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Modular backends (#138)
* Add some trait to formalize backends. * Use the generic backend trait.
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::shape::Dim;
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
@ -963,19 +964,19 @@ impl Tensor {
|
||||
|
||||
/// If the target device is the same as the tensor device, only a shallow copy is performed.
|
||||
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
|
||||
if self.device().same_id(device) {
|
||||
if self.device().same_device(device) {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let storage = match (self.storage.as_ref(), device) {
|
||||
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
|
||||
Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?)
|
||||
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||
// are the same.
|
||||
let cpu_storage = storage.to_cpu_storage()?;
|
||||
Storage::Cuda(cuda.cuda_from_cpu_storage(&cpu_storage)?)
|
||||
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||
};
|
||||
|
Reference in New Issue
Block a user