diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index b4a490cd..f570d2c5 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -8,7 +8,26 @@ use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::path::Path;
-use std::sync::{Arc, RwLock};
+use std::sync::{Arc, RwLock, TryLockError};
+
+/// Simple way to catch lock error without
+/// depending on T
+#[derive(thiserror::Error, Debug)]
+pub enum LockError {
+    #[error("{0}")]
+    Poisoned(String),
+    #[error("Would block")]
+    WouldBlock,
+}
+
+impl<T> From<TryLockError<T>> for MetalError {
+    fn from(value: TryLockError<T>) -> Self {
+        match value {
+            TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
+            TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
+        }
+    }
+}
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -24,6 +43,8 @@ pub enum MetalError {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+    #[error("{0:?}")]
+    LockError(LockError),
 }
 
 impl From<String> for MetalError {
@@ -106,10 +127,13 @@ impl MetalDevice {
         &self.command_queue
     }
 
-    pub fn command_buffer(&self) -> CommandBuffer {
-        let mut command_buffer_lock = self.command_buffer.try_write().unwrap();
+    pub fn command_buffer(&self) -> Result<CommandBuffer> {
+        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
         let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self.command_buffer_index.try_write().unwrap();
+        let mut index = self
+            .command_buffer_index
+            .try_write()
+            .map_err(MetalError::from)?;
         if *index > self.compute_per_buffer {
             command_buffer.commit();
             command_buffer = self.command_queue.new_command_buffer().to_owned();
@@ -117,11 +141,11 @@ impl MetalDevice {
             *index = 0;
         }
         *index += 1;
-        command_buffer
+        Ok(command_buffer)
     }
 
-    pub fn wait_until_completed(&self) {
-        let mut command_buffer = self.command_buffer.try_write().unwrap();
+    pub fn wait_until_completed(&self) -> Result<()> {
+        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
         match command_buffer.status() {
             metal::MTLCommandBufferStatus::Committed
             | metal::MTLCommandBufferStatus::Scheduled
@@ -133,6 +157,7 @@ impl MetalDevice {
         command_buffer.commit();
         command_buffer.wait_until_completed();
         *command_buffer = self.command_queue.new_command_buffer().to_owned();
+        Ok(())
     }
 
     pub fn kernels(&self) -> &Kernels {
@@ -148,7 +173,12 @@ impl MetalDevice {
     /// This means the buffer data cannot be read on the CPU directly.
     ///
     /// [`name`] is only used to keep track of the resource origin in case of bugs
-    pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
+    pub fn new_buffer(
+        &self,
+        element_count: usize,
+        dtype: DType,
+        name: &str,
+    ) -> Result<Arc<Buffer>> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
         self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
@@ -158,7 +188,7 @@ impl MetalDevice {
     /// This means the buffer can be read on the CPU but will require manual
     /// synchronization when the CPU memory is modified
     /// Used as a bridge to gather data back from the GPU
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
         self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
     }
 
@@ -168,7 +198,7 @@ impl MetalDevice {
     /// This method will block the computation because of the
     /// lack of lifetime management through the GPU.
     /// Internal comment for technical details.
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
         let size = core::mem::size_of_val(data) as NSUInteger;
         let tmp = self.device.new_buffer_with_data(
             data.as_ptr() as *const core::ffi::c_void,
@@ -179,8 +209,8 @@ impl MetalDevice {
             size,
             metal::MTLResourceOptions::StorageModePrivate,
             "with_data",
-        );
-        let command_buffer = self.command_buffer();
+        )?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -196,8 +226,8 @@ impl MetalDevice {
         // Putting this wait forces the GPU buffer to be filled
         // with the actual data allowing the CPU storage todo
         // deallocate properly.
-        self.wait_until_completed();
-        real
+        self.wait_until_completed()?;
+        Ok(real)
     }
 
     /// The critical allocator algorithm
@@ -206,13 +236,13 @@ impl MetalDevice {
         size: NSUInteger,
         option: MTLResourceOptions,
         _name: &str,
-    ) -> Arc<Buffer> {
-        let mut buffers = self.buffers.try_write().unwrap();
+    ) -> Result<Arc<Buffer>> {
+        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
 
         for sub in &mut *subbuffers {
             if Arc::strong_count(sub) == 1 {
-                return sub.clone();
+                return Ok(sub.clone());
             }
         }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
@@ -226,8 +256,7 @@ impl MetalDevice {
                 .collect();
             *subbuffers = newbuffers;
         }
-
-        new_buffer
+        Ok(new_buffer)
     }
 
     /// Create a metal GPU capture trace on [`path`].
@@ -279,9 +308,9 @@ impl BackendStorage for MetalStorage {
                 self.dtype
             );
         }
-        let buffer = self.device.new_buffer_managed(self.buffer.length());
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         {
-            let command_buffer = self.device.command_buffer();
+            let command_buffer = self.device.command_buffer()?;
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
@@ -290,7 +319,7 @@ impl BackendStorage for MetalStorage {
             blit.update_fence(&self.device.fence);
             blit.end_encoding();
         }
-        self.device.wait_until_completed();
+        self.device.wait_until_completed()?;
 
         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
@@ -310,8 +339,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "affine");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "affine")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "affine_f32",
@@ -361,8 +390,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "powf");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "powf")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "powf_f32",
@@ -410,8 +439,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "elu");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "elu")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "elu_f32",
@@ -497,8 +526,8 @@ impl BackendStorage for MetalStorage {
         if dtype == DType::U32 {
             crate::bail!("Implement return index reduce op");
         }
-        let buffer = device.new_buffer(dst_el, dtype, "reduce");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
             &command_buffer,
@@ -523,8 +552,8 @@ impl BackendStorage for MetalStorage {
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, "todtype");
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, "todtype")?;
+        let command_buffer = device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
@@ -576,8 +605,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         command_buffer.set_label(B::KERNEL);
         if layout.is_contiguous() && layout.start_offset() == 0 {
             use candle_metal_kernels::unary::contiguous;
@@ -681,8 +710,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = lhs_l.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
             && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
             && &B::KERNEL[..1] != "b"
@@ -758,8 +787,8 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el = shape.elem_count();
         let dtype = t.dtype;
-        let buffer = self.device.new_buffer(el, dtype, "where");
-        let command_buffer = self.device.command_buffer();
+        let buffer = self.device.new_buffer(el, dtype, "where")?;
+        let command_buffer = self.device.command_buffer()?;
         if t.dtype() != f.dtype() {
             crate::bail!("Invalid ternary different dtypes for values");
         }
@@ -875,13 +904,13 @@ impl BackendStorage for MetalStorage {
         let dst_el = ids_el * left_size * right_size;
         let dtype = self.dtype;
         let device = self.device();
-        let buffer = device.new_buffer(dst_el, dtype, "index_select");
+        let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
             (DType::U32, DType::F16) => "is_u32_f16",
             (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
         };
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -916,7 +945,7 @@ impl BackendStorage for MetalStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
+        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
         let name = match self.dtype {
             DType::F32 => "sgemm",
             DType::F16 => "hgemm",
@@ -925,7 +954,7 @@ impl BackendStorage for MetalStorage {
             }
         };
 
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("matmul");
         candle_metal_kernels::call_gemm(
             &self.device.device,
@@ -946,7 +975,7 @@ impl BackendStorage for MetalStorage {
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         if src_l.is_contiguous() && self.dtype == dst.dtype() {
             command_buffer.set_label("copy_contiguous");
             let blit = command_buffer.new_blit_command_encoder();
@@ -1047,8 +1076,8 @@ impl BackendDevice for MetalDevice {
     }
 
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros");
-        let command_buffer = self.command_buffer();
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("zeros");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -1080,7 +1109,7 @@ impl BackendDevice for MetalDevice {
             CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
-        };
+        }?;
         Ok(Self::Storage::new(
             buffer.into(),
             self.clone(),
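The crux of the metal_backend.rs changes is the `From<TryLockError<T>>` conversion: every former `try_write().unwrap()` becomes `try_write().map_err(MetalError::from)?`, so a poisoned or contended lock surfaces as an error instead of a panic. The following standalone sketch shows the same pattern compiling outside candle; the two enums are trimmed stubs of the patch's types and `bump` is a hypothetical caller, not code from this patch:

use std::sync::{RwLock, TryLockError};

// Stubs mirroring the shape of the patch's LockError / MetalError;
// the real MetalError has more variants.
#[derive(thiserror::Error, Debug)]
pub enum LockError {
    #[error("{0}")]
    Poisoned(String),
    #[error("Would block")]
    WouldBlock,
}

#[derive(thiserror::Error, Debug)]
pub enum MetalError {
    #[error("{0:?}")]
    LockError(LockError),
}

impl<T> From<TryLockError<T>> for MetalError {
    fn from(value: TryLockError<T>) -> Self {
        match value {
            TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
            TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
        }
    }
}

// Hypothetical caller showing the `map_err(MetalError::from)?` pattern
// the patch applies throughout MetalDevice.
fn bump(counter: &RwLock<usize>) -> Result<usize, MetalError> {
    let mut guard = counter.try_write().map_err(MetalError::from)?;
    *guard += 1;
    Ok(*guard)
}

fn main() {
    let counter = RwLock::new(0usize);
    assert_eq!(bump(&counter).unwrap(), 1);
    // While a read guard is alive, try_write fails fast with WouldBlock
    // instead of blocking or panicking.
    let _read = counter.try_read().unwrap();
    assert!(matches!(
        bump(&counter),
        Err(MetalError::LockError(LockError::WouldBlock))
    ));
}

The generic `impl<T>` is what lets one conversion cover every guard type (`RwLockWriteGuard<CommandBuffer>`, the buffer map, the index counter) without `MetalError` depending on `T`, which is exactly the point of stringifying the poison error.
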
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index ca23f90e..94380f12 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -210,7 +210,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
     ) -> Result<(candle::MetalStorage, Shape)> {
         use candle::{backend::BackendStorage, DType};
         let device = storage.device();
-        let command_buffer = device.command_buffer();
+        let command_buffer = device.command_buffer()?;
         let kernels = device.kernels();
         let name = match storage.dtype() {
             DType::F32 => "softmax_f32",
@@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         let last_dim = layout.dims()[layout.shape().rank() - 1];
         let elem_count = layout.shape().elem_count();
 
-        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
+        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
         candle_metal_kernels::call_last_softmax(
             device.metal_device(),
             &command_buffer,
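For downstream code, the migration is mechanical: each previously infallible `MetalDevice` call now returns `Result` and takes a `?`, as the candle-nn hunks above show. A hypothetical caller might read as follows (sketch only; `zero_fill` is not part of the patch, and it assumes candle's crate-level `Result` alias plus the `Arc`, `Buffer`, and `DType` imports already used by metal_backend.rs):

// Hypothetical helper illustrating the new fallible API surface.
fn zero_fill(device: &MetalDevice, elems: usize) -> Result<Arc<Buffer>> {
    let buffer = device.new_buffer(elems, DType::F32, "zero_fill")?; // was infallible
    let command_buffer = device.command_buffer()?;                   // was infallible
    command_buffer.set_label("zero_fill");
    device.wait_until_completed()?;                                  // previously returned ()
    Ok(buffer)
}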