mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Transfer tensors between devices.
This commit is contained in:
@ -4,6 +4,8 @@ use gemm::{gemm, Parallelism};
|
|||||||
|
|
||||||
// TODO: Think about whether we would be better off with a dtype and
|
// TODO: Think about whether we would be better off with a dtype and
|
||||||
// a buffer as an owned slice of bytes.
|
// a buffer as an owned slice of bytes.
|
||||||
|
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||||
|
// intercept the oom errors to avoid panicking and provide a proper error.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum CpuStorage {
|
pub enum CpuStorage {
|
||||||
F32(Vec<f32>),
|
F32(Vec<f32>),
|
||||||
|
@ -28,8 +28,22 @@ pub enum CudaError {
|
|||||||
|
|
||||||
type Result<T> = std::result::Result<T, CudaError>;
|
type Result<T> = std::result::Result<T, CudaError>;
|
||||||
|
|
||||||
|
/// Unique identifier for cuda devices.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||||
|
pub(crate) struct DeviceId(usize);
|
||||||
|
|
||||||
|
impl DeviceId {
|
||||||
|
fn new() -> Self {
|
||||||
|
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||||
|
use std::sync::atomic;
|
||||||
|
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||||
|
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CudaDevice {
|
pub struct CudaDevice {
|
||||||
|
id: DeviceId,
|
||||||
device: Arc<cudarc::driver::CudaDevice>,
|
device: Arc<cudarc::driver::CudaDevice>,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
blas: Arc<cudarc::cublas::CudaBlas>,
|
blas: Arc<cudarc::cublas::CudaBlas>,
|
||||||
@ -48,11 +62,16 @@ impl CudaDevice {
|
|||||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
|
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
id: DeviceId::new(),
|
||||||
device,
|
device,
|
||||||
blas: Arc::new(blas),
|
blas: Arc::new(blas),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn same_id(&self, rhs: &Self) -> bool {
|
||||||
|
self.id == rhs.id
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn ordinal(&self) -> usize {
|
pub(crate) fn ordinal(&self) -> usize {
|
||||||
self.device.ordinal()
|
self.device.ordinal()
|
||||||
}
|
}
|
||||||
|
@ -66,6 +66,14 @@ impl Device {
|
|||||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn same_id(&self, rhs: &Self) -> bool {
|
||||||
|
match (self, rhs) {
|
||||||
|
(Self::Cpu, Self::Cpu) => true,
|
||||||
|
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn location(&self) -> DeviceLocation {
|
pub fn location(&self) -> DeviceLocation {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu => DeviceLocation::Cpu,
|
Self::Cpu => DeviceLocation::Cpu,
|
||||||
|
@ -17,6 +17,10 @@ impl CudaDevice {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn same_id(&self, _: &Self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn ordinal(&self) -> usize {
|
pub(crate) fn ordinal(&self) -> usize {
|
||||||
fail!()
|
fail!()
|
||||||
}
|
}
|
||||||
|
@ -504,6 +504,36 @@ impl Tensor {
|
|||||||
Ok(Tensor(Arc::new(tensor_)))
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If the target device is the same as the tensor device, only a shallow copy is performed.
|
||||||
|
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
|
||||||
|
if self.device().same_id(device) {
|
||||||
|
Ok(self.clone())
|
||||||
|
} else {
|
||||||
|
let storage = match (&self.storage, device) {
|
||||||
|
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
|
||||||
|
Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?)
|
||||||
|
}
|
||||||
|
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||||
|
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||||
|
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||||
|
// are the same.
|
||||||
|
let cpu_storage = storage.to_cpu_storage()?;
|
||||||
|
Storage::Cuda(cuda.cuda_from_cpu_storage(&cpu_storage)?)
|
||||||
|
}
|
||||||
|
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||||
|
};
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
id: TensorId::new(),
|
||||||
|
storage,
|
||||||
|
shape: self.shape.clone(),
|
||||||
|
stride: self.stride.clone(),
|
||||||
|
op: None, // TODO: Have a proper op here.
|
||||||
|
is_variable: self.is_variable,
|
||||||
|
};
|
||||||
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||||
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
||||||
/// argument.
|
/// argument.
|
||||||
|
Reference in New Issue
Block a user