mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Transfer tensors between devices.
This commit is contained in:
@ -28,8 +28,22 @@ pub enum CudaError {
|
||||
|
||||
type Result<T> = std::result::Result<T, CudaError>;
|
||||
|
||||
/// Unique identifier for cuda devices.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct DeviceId(usize);
|
||||
|
||||
impl DeviceId {
|
||||
fn new() -> Self {
|
||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||
use std::sync::atomic;
|
||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CudaDevice {
|
||||
id: DeviceId,
|
||||
device: Arc<cudarc::driver::CudaDevice>,
|
||||
#[allow(dead_code)]
|
||||
blas: Arc<cudarc::cublas::CudaBlas>,
|
||||
@ -48,11 +62,16 @@ impl CudaDevice {
|
||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
blas: Arc::new(blas),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn same_id(&self, rhs: &Self) -> bool {
|
||||
self.id == rhs.id
|
||||
}
|
||||
|
||||
pub(crate) fn ordinal(&self) -> usize {
|
||||
self.device.ordinal()
|
||||
}
|
||||
|
Reference in New Issue
Block a user