mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Use write rather than try-write on the metal rw-locks. (#2162)
This commit is contained in:
@ -100,11 +100,11 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
||||
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
|
||||
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
|
||||
let mut command_buffer = command_buffer_lock.to_owned();
|
||||
let mut index = self
|
||||
.command_buffer_index
|
||||
.try_write()
|
||||
.write()
|
||||
.map_err(MetalError::from)?;
|
||||
if *index > self.compute_per_buffer {
|
||||
command_buffer.commit();
|
||||
@ -119,7 +119,7 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) -> Result<()> {
|
||||
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
|
||||
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
|
||||
match command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
@ -179,7 +179,7 @@ impl MetalDevice {
|
||||
size,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
let subbuffers = buffers
|
||||
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||
.or_insert(vec![]);
|
||||
@ -232,7 +232,7 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
fn drop_unused_buffers(&self) -> Result<()> {
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
for subbuffers in buffers.values_mut() {
|
||||
let newbuffers = subbuffers
|
||||
.iter()
|
||||
@ -251,7 +251,7 @@ impl MetalDevice {
|
||||
option: MTLResourceOptions,
|
||||
_name: &str,
|
||||
) -> Result<Arc<Buffer>> {
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
||||
// Cloning also ensures we increment the strong count
|
||||
return Ok(b.clone());
|
||||
|
@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
|
||||
|
||||
mod device;
|
||||
pub use device::{DeviceId, MetalDevice};
|
||||
@ -36,6 +36,12 @@ impl<T> From<TryLockError<T>> for MetalError {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<PoisonError<T>> for MetalError {
|
||||
fn from(p: PoisonError<T>) -> Self {
|
||||
MetalError::LockError(LockError::Poisoned(p.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MetalError {
|
||||
|
Reference in New Issue
Block a user