mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Remove unwrap()
.
@@ -8,7 +8,26 @@ use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::path::Path;
-use std::sync::{Arc, RwLock};
+use std::sync::{Arc, RwLock, TryLockError};
+
+/// Simple way to catch lock error without
+/// depending on T
+#[derive(thiserror::Error, Debug)]
+pub enum LockError {
+    #[error("{0}")]
+    Poisoned(String),
+    #[error("Would block")]
+    WouldBlock,
+}
+
+impl<T> From<TryLockError<T>> for MetalError {
+    fn from(value: TryLockError<T>) -> Self {
+        match value {
+            TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
+            TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
+        }
+    }
+}
+
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
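
Note: the conversion above deliberately stringifies the poison error. `TryLockError<T>` carries the lock guard, so storing it directly would make `MetalError` generic over `T` and tie it to the guard's lifetime; erasing it to a `String` keeps the error type simple. A minimal, self-contained sketch of the same pattern (names chosen for illustration, not candle's API):

    use std::sync::{RwLock, TryLockError};

    #[derive(Debug)]
    enum LockError {
        Poisoned(String), // poison message with the guard type erased
        WouldBlock,
    }

    impl<T> From<TryLockError<T>> for LockError {
        fn from(value: TryLockError<T>) -> Self {
            match value {
                // `p` owns the guard; `to_string()` keeps only the message
                TryLockError::Poisoned(p) => LockError::Poisoned(p.to_string()),
                TryLockError::WouldBlock => LockError::WouldBlock,
            }
        }
    }

    fn read_len(lock: &RwLock<Vec<u8>>) -> Result<usize, LockError> {
        // `?` applies the From impl above, erasing the guard type
        Ok(lock.try_read()?.len())
    }
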
@@ -24,6 +43,8 @@ pub enum MetalError {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+    #[error("{0:?}")]
+    LockError(LockError),
 }

 impl From<String> for MetalError {
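
Note: with the `LockError(LockError)` variant in place, lock failures can flow into candle's crate-level error in two steps: `map_err(MetalError::from)` converts the `TryLockError<T>`, and `?` then lifts the `MetalError` into the function's error type. A hedged stand-alone sketch of that chain (simplified stand-ins, not candle's real types):

    use std::sync::{RwLock, TryLockError};

    #[derive(Debug)]
    struct MetalError(String);

    impl<T> From<TryLockError<T>> for MetalError {
        fn from(e: TryLockError<T>) -> Self {
            MetalError(e.to_string())
        }
    }

    #[derive(Debug)]
    struct Error(MetalError); // stand-in for candle's crate error

    impl From<MetalError> for Error {
        fn from(e: MetalError) -> Self {
            Error(e)
        }
    }

    fn bump(counter: &RwLock<u64>) -> Result<u64, Error> {
        // map_err picks MetalError; `?` then lifts MetalError into Error
        let mut guard = counter.try_write().map_err(MetalError::from)?;
        *guard += 1;
        Ok(*guard)
    }
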
@@ -106,10 +127,13 @@ impl MetalDevice {
         &self.command_queue
     }

-    pub fn command_buffer(&self) -> CommandBuffer {
-        let mut command_buffer_lock = self.command_buffer.try_write().unwrap();
+    pub fn command_buffer(&self) -> Result<CommandBuffer> {
+        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
         let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self.command_buffer_index.try_write().unwrap();
+        let mut index = self
+            .command_buffer_index
+            .try_write()
+            .map_err(MetalError::from)?;
         if *index > self.compute_per_buffer {
             command_buffer.commit();
             command_buffer = self.command_queue.new_command_buffer().to_owned();
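
Note: `command_buffer()` batches work: the same command buffer is handed out until `compute_per_buffer` operations have been encoded, then it is committed and replaced. A simplified, runnable model of that logic with the new error propagation (plain types standing in for metal objects):

    use std::sync::RwLock;

    struct Device {
        current: RwLock<u64>, // id of the "command buffer" in flight
        index: RwLock<usize>,
        compute_per_buffer: usize,
    }

    impl Device {
        fn command_buffer(&self) -> Result<u64, String> {
            let mut current = self.current.try_write().map_err(|e| e.to_string())?;
            let mut index = self.index.try_write().map_err(|e| e.to_string())?;
            if *index > self.compute_per_buffer {
                *current += 1; // "commit" and open a fresh buffer
                *index = 0;
            }
            *index += 1;
            Ok(*current)
        }
    }
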
@@ -117,11 +141,11 @@ impl MetalDevice {
             *index = 0;
         }
         *index += 1;
-        command_buffer
+        Ok(command_buffer)
     }

-    pub fn wait_until_completed(&self) {
-        let mut command_buffer = self.command_buffer.try_write().unwrap();
+    pub fn wait_until_completed(&self) -> Result<()> {
+        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
         match command_buffer.status() {
             metal::MTLCommandBufferStatus::Committed
             | metal::MTLCommandBufferStatus::Scheduled
@@ -133,6 +157,7 @@ impl MetalDevice {
         command_buffer.commit();
         command_buffer.wait_until_completed();
         *command_buffer = self.command_queue.new_command_buffer().to_owned();
+        Ok(())
     }

     pub fn kernels(&self) -> &Kernels {
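
Note: `wait_until_completed` now surfaces lock failures too, and the status match above guards against re-submitting a buffer that is already in flight. One plausible shape of such a check with the `metal` crate; the lines elided between these two hunks decide what candle actually does in each arm:

    use metal::{CommandBuffer, MTLCommandBufferStatus};

    fn flush(command_buffer: &CommandBuffer) {
        match command_buffer.status() {
            // already handed to the GPU: nothing left to encode
            MTLCommandBufferStatus::Committed
            | MTLCommandBufferStatus::Scheduled
            | MTLCommandBufferStatus::Completed => {}
            // still open: commit before waiting
            _ => command_buffer.commit(),
        }
        command_buffer.wait_until_completed();
    }
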
@@ -148,7 +173,12 @@ impl MetalDevice {
     /// This means the buffer data cannot be read on the CPU directly.
     ///
     /// [`name`] is only used to keep track of the resource origin in case of bugs
-    pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
+    pub fn new_buffer(
+        &self,
+        element_count: usize,
+        dtype: DType,
+        name: &str,
+    ) -> Result<Arc<Buffer>> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
         self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
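
Note: `new_buffer` sizes the allocation in elements and multiplies by the dtype width, while `new_buffer_managed` below takes a byte size directly. The arithmetic, isolated (dtype widths assumed for illustration):

    // element count -> byte size, as in new_buffer
    fn buffer_size(element_count: usize, dtype_size_in_bytes: usize) -> u64 {
        (element_count * dtype_size_in_bytes) as u64
    }

    fn main() {
        // 1024 f32 elements at 4 bytes each need a 4096-byte buffer
        assert_eq!(buffer_size(1024, 4), 4096);
    }
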
@@ -158,7 +188,7 @@ impl MetalDevice {
     /// This means the buffer can be read on the CPU but will require manual
     /// synchronization when the CPU memory is modified
     /// Used as a bridge to gather data back from the GPU
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
         self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
     }

@@ -168,7 +198,7 @@ impl MetalDevice {
     /// This method will block the computation because of the
     /// lack of lifetime management through the GPU.
     /// Internal comment for technical details.
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
         let size = core::mem::size_of_val(data) as NSUInteger;
         let tmp = self.device.new_buffer_with_data(
             data.as_ptr() as *const core::ffi::c_void,
@@ -179,8 +209,8 @@ impl MetalDevice {
             size,
             metal::MTLResourceOptions::StorageModePrivate,
             "with_data",
-        );
-        let command_buffer = self.command_buffer();
+        )?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -196,8 +226,8 @@ impl MetalDevice {
         // Putting this wait forces the GPU buffer to be filled
         // with the actual data allowing the CPU storage to
         // deallocate properly.
-        self.wait_until_completed();
-        real
+        self.wait_until_completed()?;
+        Ok(real)
     }

     /// The critical allocator algorithm
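
Note: the hunks above cover the upload path: the host slice is first copied into a CPU-visible staging buffer, a blit schedules the copy into a private GPU buffer, and the final wait guarantees the copy has finished before the borrowed slice is dropped. A hedged sketch of that sequence using the `metal` crate directly (with a fresh queue instead of candle's shared, fenced one):

    use metal::{Device, MTLResourceOptions, NSUInteger};

    fn upload_f32(device: &Device, data: &[f32]) -> metal::Buffer {
        let size = core::mem::size_of_val(data) as NSUInteger;
        // CPU-visible staging copy of `data`
        let tmp = device.new_buffer_with_data(
            data.as_ptr() as *const core::ffi::c_void,
            size,
            MTLResourceOptions::StorageModeManaged,
        );
        // GPU-only destination
        let real = device.new_buffer(size, MTLResourceOptions::StorageModePrivate);
        let queue = device.new_command_queue();
        let command_buffer = queue.new_command_buffer();
        let blit = command_buffer.new_blit_command_encoder();
        blit.copy_from_buffer(&tmp, 0, &real, 0, size);
        blit.end_encoding();
        command_buffer.commit();
        // block so the staging buffer (and `data`) may be freed safely
        command_buffer.wait_until_completed();
        real
    }
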
@@ -206,13 +236,13 @@ impl MetalDevice {
         size: NSUInteger,
         option: MTLResourceOptions,
         _name: &str,
-    ) -> Arc<Buffer> {
-        let mut buffers = self.buffers.try_write().unwrap();
+    ) -> Result<Arc<Buffer>> {
+        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);

         for sub in &mut *subbuffers {
             if Arc::strong_count(sub) == 1 {
-                return sub.clone();
+                return Ok(sub.clone());
             }
         }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
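
Note: the allocator reuses buffers by reference counting alone: the pool keeps one `Arc` per buffer, so `Arc::strong_count(sub) == 1` means no tensor still holds it and it can be recycled. A simplified, runnable model of that policy:

    use std::collections::HashMap;
    use std::sync::Arc;

    struct Pool {
        // buffers grouped by size, as in candle's (size, option) key
        buffers: HashMap<u64, Vec<Arc<Vec<u8>>>>,
    }

    impl Pool {
        fn allocate(&mut self, size: u64) -> Arc<Vec<u8>> {
            let subbuffers = self.buffers.entry(size).or_insert_with(Vec::new);
            for sub in subbuffers.iter() {
                if Arc::strong_count(sub) == 1 {
                    return sub.clone(); // only the pool holds it: recycle
                }
            }
            let new_buffer = Arc::new(vec![0u8; size as usize]);
            subbuffers.push(new_buffer.clone());
            new_buffer
        }
    }
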
@@ -226,8 +256,7 @@ impl MetalDevice {
                 .collect();
             *subbuffers = newbuffers;
         }
-
-        new_buffer
+        Ok(new_buffer)
     }

     /// Create a metal GPU capture trace on [`path`].
@@ -279,9 +308,9 @@ impl BackendStorage for MetalStorage {
                 self.dtype
             );
         }
-        let buffer = self.device.new_buffer_managed(self.buffer.length());
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         {
-            let command_buffer = self.device.command_buffer();
+            let command_buffer = self.device.command_buffer()?;
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
@@ -290,7 +319,7 @@ impl BackendStorage for MetalStorage {
             blit.update_fence(&self.device.fence);
             blit.end_encoding();
         }
-        self.device.wait_until_completed();
+        self.device.wait_until_completed()?;

         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
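
Note: `to_cpu` blits the private buffer into a managed one, waits, then reads it back with `read_to_vec`. A plausible implementation of such a helper over the `metal` crate (candle's actual helper may differ in details):

    use metal::Buffer;

    /// Copy `n` elements of type `T` out of a CPU-visible buffer.
    fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
        let ptr = buffer.contents() as *const T;
        assert!(!ptr.is_null());
        // Safety: the caller guarantees the buffer holds at least
        // `n` elements of `T` and that the GPU writes have completed.
        let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
        slice.to_vec()
    }
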
@@ -310,8 +339,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;

-        let buffer = device.new_buffer(el, self.dtype, "affine");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "affine")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "affine_f32",
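
Note: from here on the change is mechanical: each op below (powf, elu, reduce, to_dtype, unary, binary, where_cond, index_select, matmul, zeros) swaps the same two `.unwrap()` calls for `?`. The shared skeleton, with stub types standing in for candle's (illustrative only):

    struct Device;
    struct Buffer;
    struct CommandBuffer;

    impl Device {
        fn new_buffer(&self, _el: usize, _name: &str) -> Result<Buffer, String> {
            Ok(Buffer)
        }
        fn command_buffer(&self) -> Result<CommandBuffer, String> {
            Ok(CommandBuffer)
        }
    }

    fn run_op(device: &Device, el: usize, name: &str) -> Result<Buffer, String> {
        let buffer = device.new_buffer(el, name)?;      // was `.unwrap()`
        let _command_buffer = device.command_buffer()?; // was `.unwrap()`
        // ... encode the kernel into the command buffer here ...
        Ok(buffer)
    }
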
@@ -361,8 +390,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;

-        let buffer = device.new_buffer(el, self.dtype, "powf");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "powf")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "powf_f32",
@@ -410,8 +439,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;

-        let buffer = device.new_buffer(el, self.dtype, "elu");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "elu")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "elu_f32",
@@ -497,8 +526,8 @@ impl BackendStorage for MetalStorage {
         if dtype == DType::U32 {
             crate::bail!("Implement return index reduce op");
         }
-        let buffer = device.new_buffer(dst_el, dtype, "reduce");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
             &command_buffer,
@@ -523,8 +552,8 @@ impl BackendStorage for MetalStorage {
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, "todtype");
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, "todtype")?;
+        let command_buffer = device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
@@ -576,8 +605,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         command_buffer.set_label(B::KERNEL);
         if layout.is_contiguous() && layout.start_offset() == 0 {
             use candle_metal_kernels::unary::contiguous;
@@ -681,8 +710,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = lhs_l.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
             && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
             && &B::KERNEL[..1] != "b"
@@ -758,8 +787,8 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el = shape.elem_count();
         let dtype = t.dtype;
-        let buffer = self.device.new_buffer(el, dtype, "where");
-        let command_buffer = self.device.command_buffer();
+        let buffer = self.device.new_buffer(el, dtype, "where")?;
+        let command_buffer = self.device.command_buffer()?;
         if t.dtype() != f.dtype() {
             crate::bail!("Invalid ternary different dtypes for values");
         }
@@ -875,13 +904,13 @@ impl BackendStorage for MetalStorage {
         let dst_el = ids_el * left_size * right_size;
         let dtype = self.dtype;
         let device = self.device();
-        let buffer = device.new_buffer(dst_el, dtype, "index_select");
+        let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
             (DType::U32, DType::F16) => "is_u32_f16",
             (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
         };
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -916,7 +945,7 @@ impl BackendStorage for MetalStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
+        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
         let name = match self.dtype {
             DType::F32 => "sgemm",
             DType::F16 => "hgemm",
@@ -925,7 +954,7 @@ impl BackendStorage for MetalStorage {
             }
         };

-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("matmul");
         candle_metal_kernels::call_gemm(
             &self.device.device,
@@ -946,7 +975,7 @@ impl BackendStorage for MetalStorage {
     }

     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         if src_l.is_contiguous() && self.dtype == dst.dtype() {
             command_buffer.set_label("copy_contiguous");
             let blit = command_buffer.new_blit_command_encoder();
@@ -1047,8 +1076,8 @@ impl BackendDevice for MetalDevice {
     }

     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros");
-        let command_buffer = self.command_buffer();
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("zeros");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -1080,7 +1109,7 @@ impl BackendDevice for MetalDevice {
             CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
-        };
+        }?;
         Ok(Self::Storage::new(
             buffer.into(),
             self.clone(),
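
Note: the `};` to `}?;` change applies `?` to the whole `match` expression at once: every arm yields a `Result<Arc<Buffer>>`, and the single trailing `?` unwraps whichever arm ran. The same trick in miniature:

    fn parse(kind: u8, s: &str) -> Result<i64, std::num::ParseIntError> {
        // each arm returns a Result; the trailing `?` unwraps the match
        let n = match kind {
            0 => s.parse::<i32>().map(i64::from),
            _ => s.parse::<i64>(),
        }?;
        Ok(n * 2)
    }
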

@@ -210,7 +210,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
     ) -> Result<(candle::MetalStorage, Shape)> {
         use candle::{backend::BackendStorage, DType};
         let device = storage.device();
-        let command_buffer = device.command_buffer();
+        let command_buffer = device.command_buffer()?;
         let kernels = device.kernels();
         let name = match storage.dtype() {
             DType::F32 => "softmax_f32",
@@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {

         let last_dim = layout.dims()[layout.shape().rank() - 1];
         let elem_count = layout.shape().elem_count();
-        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
+        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
         candle_metal_kernels::call_last_softmax(
             device.metal_device(),
             &command_buffer,
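
Note: the last two hunks touch the metal path of the `SoftmaxLastDim` custom op (the hunk numbering restarts at 210, indicating a second file), so lock and allocation failures now reach the caller as a normal `Err` instead of a panic. A usage sketch, assuming a metal-enabled build and the public `candle_nn::ops::softmax_last_dim` helper:

    use candle::{Device, Tensor};

    fn softmax_demo() -> candle::Result<Tensor> {
        let device = Device::new_metal(0)?;
        let xs = Tensor::rand(0f32, 1f32, (4, 16), &device)?;
        // failures inside metal_fwd surface here instead of panicking
        candle_nn::ops::softmax_last_dim(&xs)
    }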