mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Really unique identifier for metal device ids. (#1932)
* Really unique identifier for metal device ids. * Same device.
This commit is contained in:
@ -10,6 +10,19 @@ use std::ffi::c_void;
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
|
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
|
||||||
|
|
||||||
|
/// Unique identifier for metal devices.
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||||
|
pub struct DeviceId(usize);
|
||||||
|
|
||||||
|
impl DeviceId {
|
||||||
|
/// Hands out a fresh identifier taken from a counter shared by every call in the process.
/// Ids are distinct as long as the `usize` counter does not wrap around.
fn new() -> Self {
|
||||||
|
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||||
|
use std::sync::atomic;
|
||||||
|
// A single static counter, initialised to 1, so the first id handed out is 1.
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||||
|
// `fetch_add` returns the value held before the increment; the read and the bump happen
// as one atomic operation, so two callers never receive the same value.
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Simple way to catch lock error without
|
/// Simple way to catch lock error without
|
||||||
/// depending on T
|
/// depending on T
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -64,6 +77,10 @@ type AllocatedBuffers = Arc<RwLock<BufferMap>>;
|
|||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct MetalDevice {
|
pub struct MetalDevice {
|
||||||
|
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
|
||||||
|
/// the device itself.
|
||||||
|
id: DeviceId,
|
||||||
|
|
||||||
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
||||||
device: metal::Device,
|
device: metal::Device,
|
||||||
|
|
||||||
@ -108,7 +125,7 @@ pub struct MetalDevice {
|
|||||||
|
|
||||||
impl std::fmt::Debug for MetalDevice {
|
impl std::fmt::Debug for MetalDevice {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "MetalDevice({:?})", self.device.registry_id())
|
write!(f, "MetalDevice({:?})", self.id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,8 +138,8 @@ impl std::ops::Deref for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MetalDevice {
|
impl MetalDevice {
|
||||||
pub fn id(&self) -> NSUInteger {
|
pub fn id(&self) -> DeviceId {
|
||||||
self.registry_id()
|
self.id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn metal_device(&self) -> &metal::Device {
|
pub fn metal_device(&self) -> &metal::Device {
|
||||||
@ -1117,8 +1134,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
padding: params.padding,
|
padding: params.padding,
|
||||||
output_padding: params.output_padding,
|
output_padding: params.output_padding,
|
||||||
c_out: params.c_out,
|
c_out: params.c_out,
|
||||||
out_h: out_h,
|
out_h,
|
||||||
out_w: out_w,
|
out_w,
|
||||||
b_size: params.b_size,
|
b_size: params.b_size,
|
||||||
input_dims: l.dims(),
|
input_dims: l.dims(),
|
||||||
input_stride: l.stride(),
|
input_stride: l.stride(),
|
||||||
@ -1867,6 +1884,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
MTLResourceOptions::StorageModeManaged,
|
MTLResourceOptions::StorageModeManaged,
|
||||||
)));
|
)));
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
id: DeviceId::new(),
|
||||||
device,
|
device,
|
||||||
command_queue,
|
command_queue,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
@ -1885,7 +1903,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn same_device(&self, rhs: &Self) -> bool {
|
fn same_device(&self, rhs: &Self) -> bool {
|
||||||
self.device.registry_id() == rhs.device.registry_id()
|
self.id == rhs.id
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||||
|
@ -44,9 +44,19 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||||
let lhs = self.device().location();
|
let lhs_device = self.device();
|
||||||
let rhs = rhs.device().location();
|
let rhs_device = rhs.device();
|
||||||
if lhs != rhs {
|
let lhs = lhs_device.location();
|
||||||
|
let rhs = rhs_device.location();
|
||||||
|
let same_device = if self.device().is_metal() {
|
||||||
|
// On metal, we require the device to be exactly the same rather than
|
||||||
|
// having the same location. In cuda this is not necessary as all CudaDevice on the
|
||||||
|
// same GPU will use the same cuda stream.
|
||||||
|
lhs_device.same_device(&rhs_device)
|
||||||
|
} else {
|
||||||
|
lhs == rhs
|
||||||
|
};
|
||||||
|
if !same_device {
|
||||||
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
|
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Reference in New Issue
Block a user