Transfer tensors between devices.

This commit is contained in:
laurent
2023-06-23 08:35:22 +01:00
parent fc41ccb5bb
commit 3b550a56dc
5 changed files with 63 additions and 0 deletions

View File

@ -504,6 +504,36 @@ impl Tensor {
Ok(Tensor(Arc::new(tensor_)))
}
/// Returns a tensor holding the same data as `self` on the given `device`.
///
/// If the target device is the same as the tensor device, only a shallow copy is performed.
/// Otherwise the storage is converted for the target device and wrapped in a new tensor
/// that gets a fresh [`TensorId`], keeps the shape, stride and `is_variable` flag of `self`,
/// and has no op attached.
///
/// # Errors
///
/// Any error returned by the underlying storage conversion is propagated to the caller.
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
    // Guard clause: nothing has to be transferred when both devices share the same id.
    if self.device().same_id(device) {
        return Ok(self.clone());
    }

    let storage = match (&self.storage, device) {
        (Storage::Cpu(src), Device::Cpu) => Storage::Cpu(src.clone()),
        (Storage::Cpu(src), Device::Cuda(dst)) => Storage::Cuda(dst.cuda_from_cpu_storage(src)?),
        (Storage::Cuda(src), Device::Cpu) => Storage::Cpu(src.to_cpu_storage()?),
        (Storage::Cuda(src), Device::Cuda(dst)) => {
            // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
            // are the same.
            let host_copy = src.to_cpu_storage()?;
            Storage::Cuda(dst.cuda_from_cpu_storage(&host_copy)?)
        }
    };

    Ok(Tensor(Arc::new(Tensor_ {
        id: TensorId::new(),
        storage,
        shape: self.shape.clone(),
        stride: self.stride.clone(),
        // TODO: Have a proper op here.
        op: None,
        is_variable: self.is_variable,
    })))
}
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
/// argument.