diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 73a141ea..bf501e24 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -10,6 +10,19 @@ use std::ffi::c_void; use std::path::Path; use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError}; +/// Unique identifier for metal devices. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct DeviceId(usize); + +impl DeviceId { + fn new() -> Self { + // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 + use std::sync::atomic; + static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); + Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) + } +} + /// Simple way to catch lock error without /// depending on T #[derive(thiserror::Error, Debug)] @@ -64,6 +77,10 @@ type AllocatedBuffers = Arc>; #[derive(Clone)] pub struct MetalDevice { + /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than + /// the device itself. 
+ id: DeviceId, + /// Raw metal device: device: metal::Device, @@ -108,7 +125,7 @@ pub struct MetalDevice { impl std::fmt::Debug for MetalDevice { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "MetalDevice({:?})", self.device.registry_id()) + write!(f, "MetalDevice({:?})", self.id) } } @@ -121,8 +138,8 @@ impl std::ops::Deref for MetalDevice { } impl MetalDevice { - pub fn id(&self) -> NSUInteger { - self.registry_id() + pub fn id(&self) -> DeviceId { + self.id } pub fn metal_device(&self) -> &metal::Device { @@ -1117,8 +1134,8 @@ impl BackendStorage for MetalStorage { padding: params.padding, output_padding: params.output_padding, c_out: params.c_out, - out_h: out_h, - out_w: out_w, + out_h, + out_w, b_size: params.b_size, input_dims: l.dims(), input_stride: l.stride(), @@ -1867,6 +1884,7 @@ impl BackendDevice for MetalDevice { MTLResourceOptions::StorageModeManaged, ))); Ok(Self { + id: DeviceId::new(), device, command_queue, command_buffer, @@ -1885,7 +1903,7 @@ impl BackendDevice for MetalDevice { } fn same_device(&self, rhs: &Self) -> bool { - self.device.registry_id() == rhs.device.registry_id() + self.id == rhs.id } unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 36620dd9..8a0637e3 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -44,9 +44,19 @@ impl Storage { } pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { - let lhs = self.device().location(); - let rhs = rhs.device().location(); - if lhs != rhs { + let lhs_device = self.device(); + let rhs_device = rhs.device(); + let lhs = lhs_device.location(); + let rhs = rhs_device.location(); + let same_device = if self.device().is_metal() { + // On metal, we require the device to be exactly the same rather than + // having the same location. 
In cuda this is not necessary as all CudaDevice on the + // same GPU will use the same cuda stream. + lhs_device.same_device(&rhs_device) + } else { + lhs == rhs + }; + if !same_device { Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt()) } else { Ok(())