From 3b550a56dca4792bb7546257acf4339e02c6b801 Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 23 Jun 2023 08:35:22 +0100
Subject: [PATCH] Transfer tensors between devices.

---
 src/cpu_backend.rs        |  2 ++
 src/cuda_backend.rs       | 19 +++++++++++++++++++
 src/device.rs             |  8 ++++++++
 src/dummy_cuda_backend.rs |  4 ++++
 src/tensor.rs             | 30 ++++++++++++++++++++++++++++++
 5 files changed, 63 insertions(+)

diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 2a393f5f..1bd272d8 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -4,6 +4,8 @@ use gemm::{gemm, Parallelism};
 
 // TODO: Think about whether we would be better off with a dtype and
 // a buffer as an owned slice of bytes.
+// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
+// intercept the oom errors to avoid panicking and provide a proper error.
 #[derive(Debug, Clone)]
 pub enum CpuStorage {
     F32(Vec<f32>),
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 8704077c..7b6dd655 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -28,8 +28,22 @@ pub enum CudaError {
 
 type Result<T> = std::result::Result<T, CudaError>;
 
+/// Unique identifier for cuda devices.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub(crate) struct DeviceId(usize);
+
+impl DeviceId {
+    fn new() -> Self {
+        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
+        use std::sync::atomic;
+        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
+        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct CudaDevice {
+    id: DeviceId,
     device: Arc<cudarc::driver::CudaDevice>,
     #[allow(dead_code)]
     blas: Arc<cudarc::cublas::CudaBlas>,
@@ -48,11 +62,16 @@ impl CudaDevice {
         let device = cudarc::driver::CudaDevice::new(ordinal)?;
         let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
         Ok(Self {
+            id: DeviceId::new(),
             device,
             blas: Arc::new(blas),
         })
     }
 
+    pub(crate) fn same_id(&self, rhs: &Self) -> bool {
+        self.id == rhs.id
+    }
+
     pub(crate) fn ordinal(&self) -> usize {
         self.device.ordinal()
     }
diff --git a/src/device.rs b/src/device.rs
index 8acb1338..62afd905 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -66,6 +66,14 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }
 
+    pub fn same_id(&self, rhs: &Self) -> bool {
+        match (self, rhs) {
+            (Self::Cpu, Self::Cpu) => true,
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
+            _ => false,
+        }
+    }
+
     pub fn location(&self) -> DeviceLocation {
         match self {
             Self::Cpu => DeviceLocation::Cpu,
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index f8669494..fbcfe758 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -17,6 +17,10 @@ impl CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn same_id(&self, _: &Self) -> bool {
+        true
+    }
+
     pub(crate) fn ordinal(&self) -> usize {
         fail!()
     }
diff --git a/src/tensor.rs b/src/tensor.rs
index a0f6fa11..549dafbe 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -504,6 +504,36 @@ impl Tensor {
         Ok(Tensor(Arc::new(tensor_)))
     }
 
+    /// If the target device is the same as the tensor device, only a shallow copy is performed.
+    pub fn to_device(&self, device: &Device) -> Result<Self> {
+        if self.device().same_id(device) {
+            Ok(self.clone())
+        } else {
+            let storage = match (&self.storage, device) {
+                (Storage::Cpu(storage), Device::Cuda(cuda)) => {
+                    Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?)
+                }
+                (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
+                (Storage::Cuda(storage), Device::Cuda(cuda)) => {
+                    // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
+                    // are the same.
+                    let cpu_storage = storage.to_cpu_storage()?;
+                    Storage::Cuda(cuda.cuda_from_cpu_storage(&cpu_storage)?)
+                }
+                (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
+            };
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage,
+                shape: self.shape.clone(),
+                stride: self.stride.clone(),
+                op: None, // TODO: Have a proper op here.
+                is_variable: self.is_variable,
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
+    }
+
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
     /// argument.