From 01794dc16ef8d896933d61e9bd9b8a981cd51930 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 5 May 2024 07:22:46 +0200 Subject: [PATCH] Use write rather than try-write on the metal rw-locks. (#2162) --- candle-core/src/metal_backend/device.rs | 12 ++++++------ candle-core/src/metal_backend/mod.rs | 8 +++++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 44af7649..785fe621 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -100,11 +100,11 @@ impl MetalDevice { } pub fn command_buffer(&self) -> Result { - let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; + let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?; let mut command_buffer = command_buffer_lock.to_owned(); let mut index = self .command_buffer_index - .try_write() + .write() .map_err(MetalError::from)?; if *index > self.compute_per_buffer { command_buffer.commit(); @@ -119,7 +119,7 @@ impl MetalDevice { } pub fn wait_until_completed(&self) -> Result<()> { - let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; + let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?; match command_buffer.status() { metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled @@ -179,7 +179,7 @@ impl MetalDevice { size, MTLResourceOptions::StorageModeManaged, ); - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let mut buffers = self.buffers.write().map_err(MetalError::from)?; let subbuffers = buffers .entry((size, MTLResourceOptions::StorageModeManaged)) .or_insert(vec![]); @@ -232,7 +232,7 @@ impl MetalDevice { } fn drop_unused_buffers(&self) -> Result<()> { - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let mut buffers = self.buffers.write().map_err(MetalError::from)?; for subbuffers in buffers.values_mut() { let newbuffers = subbuffers .iter() @@ -251,7 +251,7 @@ impl MetalDevice { option: MTLResourceOptions, _name: &str, ) -> Result> { - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let mut buffers = self.buffers.write().map_err(MetalError::from)?; if let Some(b) = self.find_available_buffer(size, option, &buffers) { // Cloning also ensures we increment the strong count return Ok(b.clone()); diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index e00566ca..9273eda8 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels}; use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::ffi::c_void; -use std::sync::{Arc, Mutex, RwLock, TryLockError}; +use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError}; mod device; pub use device::{DeviceId, MetalDevice}; @@ -36,6 +36,12 @@ impl From> for MetalError { } } +impl From> for MetalError { + fn from(p: PoisonError) -> Self { + MetalError::LockError(LockError::Poisoned(p.to_string())) + } +} + /// Metal related errors #[derive(thiserror::Error, Debug)] pub enum MetalError {