Remove unwrap().

Nicolas Patry
2023-12-15 12:23:28 +01:00
parent 8b5059e951
commit aa04015098
2 changed files with 77 additions and 48 deletions
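In short: the Metal backend grabbed its internal `RwLock`s with `try_write().unwrap()`, which panics if a lock is poisoned or contended. This commit makes those paths fallible instead: a new `LockError` type, a generic `From<TryLockError<T>>` conversion into `MetalError`, and `Result` return types threaded through the device API. A minimal sketch of the pattern, using simplified stand-ins for the real `MetalError`/`LockError` (the real definitions are in the first hunk below):

```rust
use std::sync::{RwLock, TryLockError};

// Simplified stand-ins for the error types this commit adds.
#[derive(Debug)]
pub enum LockError {
    Poisoned(String),
    WouldBlock,
}

#[derive(Debug)]
pub enum MetalError {
    LockError(LockError),
}

// Generic over T, so it works for any guard type held by the lock.
impl<T> From<TryLockError<T>> for MetalError {
    fn from(value: TryLockError<T>) -> Self {
        match value {
            TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
            TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
        }
    }
}

// Before: `lock.try_write().unwrap()` aborts on a poisoned lock.
// After: the failure is converted via the From impl and propagated with `?`.
fn bump(lock: &RwLock<usize>) -> Result<usize, MetalError> {
    let mut guard = lock.try_write().map_err(MetalError::from)?;
    *guard += 1;
    Ok(*guard)
}

fn main() {
    let lock = RwLock::new(0);
    println!("{:?}", bump(&lock)); // Ok(1)
}
```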

@@ -8,7 +8,26 @@ use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::path::Path;
-use std::sync::{Arc, RwLock};
+use std::sync::{Arc, RwLock, TryLockError};
+
+/// Simple way to catch lock error without
+/// depending on T
+#[derive(thiserror::Error, Debug)]
+pub enum LockError {
+    #[error("{0}")]
+    Poisoned(String),
+    #[error("Would block")]
+    WouldBlock,
+}
+
+impl<T> From<TryLockError<T>> for MetalError {
+    fn from(value: TryLockError<T>) -> Self {
+        match value {
+            TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
+            TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
+        }
+    }
+}
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
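A note on the design: `TryLockError<T>` is generic over the guard type and, when poisoned, carries the guard itself, so it cannot be stored in a plain `'static` error enum; converting the poison error to a `String` erases `T`, which is what the "without depending on T" doc comment refers to. A self-contained demo of the failure mode this now reports instead of panicking (the `unwrap()` style it replaces would abort here):

```rust
use std::sync::{Arc, RwLock};
use std::thread;

fn main() {
    let lock = Arc::new(RwLock::new(0u32));

    // Poison the lock: panic while holding the write guard.
    let l2 = Arc::clone(&lock);
    let _ = thread::spawn(move || {
        let _guard = l2.write().unwrap();
        panic!("boom");
    })
    .join();

    // Before this commit the backend would call try_write().unwrap() and
    // panic again here; now the error is converted and returned instead.
    match lock.try_write() {
        Ok(_) => println!("got the lock"),
        Err(e) => println!("recoverable error: {e}"),
    }
}
```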
@@ -24,6 +43,8 @@ pub enum MetalError {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+    #[error("{0:?}")]
+    LockError(LockError),
 }
 
 impl From<String> for MetalError {
@@ -106,10 +127,13 @@ impl MetalDevice {
         &self.command_queue
     }
 
-    pub fn command_buffer(&self) -> CommandBuffer {
-        let mut command_buffer_lock = self.command_buffer.try_write().unwrap();
+    pub fn command_buffer(&self) -> Result<CommandBuffer> {
+        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
         let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self.command_buffer_index.try_write().unwrap();
+        let mut index = self
+            .command_buffer_index
+            .try_write()
+            .map_err(MetalError::from)?;
         if *index > self.compute_per_buffer {
             command_buffer.commit();
             command_buffer = self.command_queue.new_command_buffer().to_owned();
@@ -117,11 +141,11 @@ impl MetalDevice {
             *index = 0;
         }
         *index += 1;
-        command_buffer
+        Ok(command_buffer)
     }
 
-    pub fn wait_until_completed(&self) {
-        let mut command_buffer = self.command_buffer.try_write().unwrap();
+    pub fn wait_until_completed(&self) -> Result<()> {
+        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
         match command_buffer.status() {
             metal::MTLCommandBufferStatus::Committed
             | metal::MTLCommandBufferStatus::Scheduled
@@ -133,6 +157,7 @@ impl MetalDevice {
         command_buffer.commit();
         command_buffer.wait_until_completed();
         *command_buffer = self.command_queue.new_command_buffer().to_owned();
+        Ok(())
     }
 
     pub fn kernels(&self) -> &Kernels {
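From here on, the hunks are mostly mechanical: every caller of `command_buffer()`, `wait_until_completed()`, and the `new_buffer*` family gains a `?`. A sketch of what a call site looks like after the change; the stub types below are stand-ins for candle's own (`MetalDevice`, `CommandBuffer`, and its `Result` alias), kept minimal so the flow compiles on its own:

```rust
// Stand-ins for candle's types, only to make the sketch self-contained.
type Result<T> = std::result::Result<T, String>;

struct CommandBuffer;
impl CommandBuffer {
    fn set_label(&self, _label: &str) {}
}

struct MetalDevice;
impl MetalDevice {
    fn command_buffer(&self) -> Result<CommandBuffer> { Ok(CommandBuffer) }
    fn wait_until_completed(&self) -> Result<()> { Ok(()) }
}

// Before: `let cb = device.command_buffer();` could panic deep inside.
// After: lock failures surface as Err and bubble up through `?`.
fn flush(device: &MetalDevice) -> Result<()> {
    let command_buffer = device.command_buffer()?;
    command_buffer.set_label("flush");
    device.wait_until_completed()?;
    Ok(())
}

fn main() {
    println!("{:?}", flush(&MetalDevice)); // Ok(())
}
```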
@@ -148,7 +173,12 @@ impl MetalDevice {
     /// This means the buffer data cannot be read on the CPU directly.
     ///
    /// [`name`] is only used to keep track of the resource origin in case of bugs
-    pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
+    pub fn new_buffer(
+        &self,
+        element_count: usize,
+        dtype: DType,
+        name: &str,
+    ) -> Result<Arc<Buffer>> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
         self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
@@ -158,7 +188,7 @@ impl MetalDevice {
     /// This means the buffer can be read on the CPU but will require manual
     /// synchronization when the CPU memory is modified
     /// Used as a bridge to gather data back from the GPU
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
         self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
     }
@@ -168,7 +198,7 @@ impl MetalDevice {
     /// This method will block the computation because of the
     /// lack of lifetime management through the GPU.
     /// Internal comment for technical details.
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
         let size = core::mem::size_of_val(data) as NSUInteger;
         let tmp = self.device.new_buffer_with_data(
             data.as_ptr() as *const core::ffi::c_void,
@@ -179,8 +209,8 @@ impl MetalDevice {
             size,
             metal::MTLResourceOptions::StorageModePrivate,
             "with_data",
-        );
-        let command_buffer = self.command_buffer();
+        )?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -196,8 +226,8 @@ impl MetalDevice {
         // Putting this wait forces the GPU buffer to be filled
         // with the actual data allowing the CPU storage todo
        // deallocate properly.
-        self.wait_until_completed();
-        real
+        self.wait_until_completed()?;
+        Ok(real)
     }
 
     /// The critical allocator algorithm
@@ -206,13 +236,13 @@ impl MetalDevice {
         size: NSUInteger,
         option: MTLResourceOptions,
         _name: &str,
-    ) -> Arc<Buffer> {
-        let mut buffers = self.buffers.try_write().unwrap();
+    ) -> Result<Arc<Buffer>> {
+        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
 
         for sub in &mut *subbuffers {
             if Arc::strong_count(sub) == 1 {
-                return sub.clone();
+                return Ok(sub.clone());
             }
         }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
@@ -226,8 +256,7 @@ impl MetalDevice {
                 .collect();
             *subbuffers = newbuffers;
         }
-
-        new_buffer
+        Ok(new_buffer)
     }
 
     /// Create a metal GPU capture trace on [`path`].
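The allocator hunks above also show the reuse heuristic behind "the critical allocator algorithm": cached buffers are `Arc`s, and a sub-buffer whose strong count is 1 is referenced only by the cache itself, so it is free to hand out again. A self-contained sketch of that idea (a `Vec<u8>` stands in for a Metal `Buffer`; `Pool`/`allocate` are illustrative names, not candle API):

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Cache of same-sized allocations; an entry with strong_count == 1
// is only referenced by the cache itself, so it can be reused.
struct Pool {
    buffers: HashMap<usize, Vec<Arc<Vec<u8>>>>,
}

impl Pool {
    fn allocate(&mut self, size: usize) -> Arc<Vec<u8>> {
        let subbuffers = self.buffers.entry(size).or_insert_with(Vec::new);
        for sub in subbuffers.iter() {
            if Arc::strong_count(sub) == 1 {
                return sub.clone(); // reuse: no live handle outside the cache
            }
        }
        let new_buffer = Arc::new(vec![0u8; size]);
        subbuffers.push(new_buffer.clone());
        new_buffer
    }
}

fn main() {
    let mut pool = Pool { buffers: HashMap::new() };
    let a = pool.allocate(16);
    let b = pool.allocate(16); // `a` is still live, so a second buffer is made
    drop(a);
    let c = pool.allocate(16); // reuses the slot that `a` released
    println!("{} {}", Arc::ptr_eq(&b, &c), pool.buffers[&16].len()); // false 2
}
```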
@@ -279,9 +308,9 @@ impl BackendStorage for MetalStorage {
                 self.dtype
             );
         }
-        let buffer = self.device.new_buffer_managed(self.buffer.length());
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         {
-            let command_buffer = self.device.command_buffer();
+            let command_buffer = self.device.command_buffer()?;
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
@@ -290,7 +319,7 @@ impl BackendStorage for MetalStorage {
             blit.update_fence(&self.device.fence);
             blit.end_encoding();
         }
-        self.device.wait_until_completed();
+        self.device.wait_until_completed()?;
 
         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
@@ -310,8 +339,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "affine");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "affine")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "affine_f32",
@@ -361,8 +390,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "powf");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "powf")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "powf_f32",
@@ -410,8 +439,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "elu");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "elu")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "elu_f32",
@@ -497,8 +526,8 @@ impl BackendStorage for MetalStorage {
         if dtype == DType::U32 {
             crate::bail!("Implement return index reduce op");
         }
-        let buffer = device.new_buffer(dst_el, dtype, "reduce");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
             &command_buffer,
@@ -523,8 +552,8 @@ impl BackendStorage for MetalStorage {
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, "todtype");
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, "todtype")?;
+        let command_buffer = device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
@@ -576,8 +605,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         command_buffer.set_label(B::KERNEL);
         if layout.is_contiguous() && layout.start_offset() == 0 {
             use candle_metal_kernels::unary::contiguous;
@@ -681,8 +710,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = lhs_l.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
             && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
             && &B::KERNEL[..1] != "b"
@@ -758,8 +787,8 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el = shape.elem_count();
         let dtype = t.dtype;
-        let buffer = self.device.new_buffer(el, dtype, "where");
-        let command_buffer = self.device.command_buffer();
+        let buffer = self.device.new_buffer(el, dtype, "where")?;
+        let command_buffer = self.device.command_buffer()?;
         if t.dtype() != f.dtype() {
             crate::bail!("Invalid ternary different dtypes for values");
         }
@@ -875,13 +904,13 @@ impl BackendStorage for MetalStorage {
         let dst_el = ids_el * left_size * right_size;
         let dtype = self.dtype;
         let device = self.device();
-        let buffer = device.new_buffer(dst_el, dtype, "index_select");
+        let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
             (DType::U32, DType::F16) => "is_u32_f16",
             (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
         };
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -916,7 +945,7 @@ impl BackendStorage for MetalStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
+        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
         let name = match self.dtype {
             DType::F32 => "sgemm",
             DType::F16 => "hgemm",
@@ -925,7 +954,7 @@ impl BackendStorage for MetalStorage {
             }
         };
 
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("matmul");
         candle_metal_kernels::call_gemm(
             &self.device.device,
} }
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer(); let command_buffer = self.device.command_buffer()?;
if src_l.is_contiguous() && self.dtype == dst.dtype() { if src_l.is_contiguous() && self.dtype == dst.dtype() {
command_buffer.set_label("copy_contiguous"); command_buffer.set_label("copy_contiguous");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
@@ -1047,8 +1076,8 @@ impl BackendDevice for MetalDevice {
     }
 
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros");
-        let command_buffer = self.command_buffer();
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("zeros");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -1080,7 +1109,7 @@ impl BackendDevice for MetalDevice {
             CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
-        };
+        }?;
         Ok(Self::Storage::new(
             buffer.into(),
             self.clone(),
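One detail in the hunk above: since every `CpuStorage` arm now returns a `Result`, the `?` is applied once to the whole `match` expression (`}?;`) rather than per arm. The same pattern in a self-contained toy (`parse` is a hypothetical function, not from the codebase):

```rust
fn parse(trim_first: bool, s: &str) -> Result<i64, std::num::ParseIntError> {
    // Each arm evaluates to a Result; a single `?` after the match
    // handles the error path for all of them at once.
    let n = match trim_first {
        true => s.trim().parse::<i64>(),
        false => s.parse::<i64>(),
    }?;
    Ok(n * 2)
}

fn main() {
    println!("{:?}", parse(false, "21")); // Ok(42)
    println!("{:?}", parse(true, " 7")); // Ok(14)
    println!("{:?}", parse(false, "x")); // Err(..)
}
```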

@@ -210,7 +210,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
     ) -> Result<(candle::MetalStorage, Shape)> {
         use candle::{backend::BackendStorage, DType};
         let device = storage.device();
-        let command_buffer = device.command_buffer();
+        let command_buffer = device.command_buffer()?;
         let kernels = device.kernels();
         let name = match storage.dtype() {
             DType::F32 => "softmax_f32",
@@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         let last_dim = layout.dims()[layout.shape().rank() - 1];
         let elem_count = layout.shape().elem_count();
-        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
+        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
         candle_metal_kernels::call_last_softmax(
             device.metal_device(),
             &command_buffer,