mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Use write rather than try-write on the metal rw-locks. (#2162)
This commit is contained in:
@ -100,11 +100,11 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
||||||
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
|
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
|
||||||
let mut command_buffer = command_buffer_lock.to_owned();
|
let mut command_buffer = command_buffer_lock.to_owned();
|
||||||
let mut index = self
|
let mut index = self
|
||||||
.command_buffer_index
|
.command_buffer_index
|
||||||
.try_write()
|
.write()
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
if *index > self.compute_per_buffer {
|
if *index > self.compute_per_buffer {
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
@ -119,7 +119,7 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn wait_until_completed(&self) -> Result<()> {
|
pub fn wait_until_completed(&self) -> Result<()> {
|
||||||
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
|
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
|
||||||
match command_buffer.status() {
|
match command_buffer.status() {
|
||||||
metal::MTLCommandBufferStatus::Committed
|
metal::MTLCommandBufferStatus::Committed
|
||||||
| metal::MTLCommandBufferStatus::Scheduled
|
| metal::MTLCommandBufferStatus::Scheduled
|
||||||
@ -179,7 +179,7 @@ impl MetalDevice {
|
|||||||
size,
|
size,
|
||||||
MTLResourceOptions::StorageModeManaged,
|
MTLResourceOptions::StorageModeManaged,
|
||||||
);
|
);
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
let subbuffers = buffers
|
let subbuffers = buffers
|
||||||
.entry((size, MTLResourceOptions::StorageModeManaged))
|
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||||
.or_insert(vec![]);
|
.or_insert(vec![]);
|
||||||
@ -232,7 +232,7 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn drop_unused_buffers(&self) -> Result<()> {
|
fn drop_unused_buffers(&self) -> Result<()> {
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
for subbuffers in buffers.values_mut() {
|
for subbuffers in buffers.values_mut() {
|
||||||
let newbuffers = subbuffers
|
let newbuffers = subbuffers
|
||||||
.iter()
|
.iter()
|
||||||
@ -251,7 +251,7 @@ impl MetalDevice {
|
|||||||
option: MTLResourceOptions,
|
option: MTLResourceOptions,
|
||||||
_name: &str,
|
_name: &str,
|
||||||
) -> Result<Arc<Buffer>> {
|
) -> Result<Arc<Buffer>> {
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
||||||
// Cloning also ensures we increment the strong count
|
// Cloning also ensures we increment the strong count
|
||||||
return Ok(b.clone());
|
return Ok(b.clone());
|
||||||
|
@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
|||||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
|
||||||
|
|
||||||
mod device;
|
mod device;
|
||||||
pub use device::{DeviceId, MetalDevice};
|
pub use device::{DeviceId, MetalDevice};
|
||||||
@ -36,6 +36,12 @@ impl<T> From<TryLockError<T>> for MetalError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> From<PoisonError<T>> for MetalError {
|
||||||
|
fn from(p: PoisonError<T>) -> Self {
|
||||||
|
MetalError::LockError(LockError::Poisoned(p.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum MetalError {
|
pub enum MetalError {
|
||||||
|
Reference in New Issue
Block a user