mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Reusing a single buffer (for now) to speed things up.
This commit is contained in:
@ -7,8 +7,9 @@ use candle_metal_kernels::Kernels;
|
||||
use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::sync::Arc;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::ops::Deref;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -36,6 +37,7 @@ impl From<String> for MetalError {
|
||||
/// Bundles the Metal objects this backend uses to encode and run kernels.
pub struct MetalDevice {
|
||||
/// The underlying Metal device handle.
device: metal::Device,
|
||||
/// Queue from which command buffers are created.
command_queue: metal::CommandQueue,
|
||||
/// The single command buffer currently shared by operations on this device.
/// It sits behind a lock because `wait_until_completed` swaps it for a
/// fresh buffer after committing the old one.
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||
/// Reference-counted `candle_metal_kernels::Kernels` instance.
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
}
|
||||
|
||||
@ -54,10 +56,6 @@ impl std::ops::Deref for MetalDevice {
|
||||
}
|
||||
|
||||
impl MetalDevice {
|
||||
// pub fn metal_device(&self) -> &metal::DeviceRef {
|
||||
// self.device.as_ref()
|
||||
// }
|
||||
|
||||
/// Returns the identifier of this device, as reported by `registry_id()`.
///
/// NOTE(review): `registry_id` is not defined on this type in the visible
/// code; presumably it is reached through the `Deref` impl on `MetalDevice`
/// — confirm.
pub fn id(&self) -> NSUInteger {
|
||||
self.registry_id()
|
||||
}
|
||||
@ -66,6 +64,19 @@ impl MetalDevice {
|
||||
&self.command_queue
|
||||
}
|
||||
|
||||
/// Returns a read guard over the device's current shared command buffer.
///
/// The guard keeps the read side of the lock held for as long as it is
/// alive.
///
/// # Panics
///
/// Panics if the lock is poisoned (the `read()` result is unwrapped).
pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
|
||||
self.command_buffer.read().unwrap()
|
||||
}
|
||||
|
||||
/// Commits the shared command buffer, waits for it to complete, and then
/// replaces it with a newly created command buffer from the command queue.
///
/// The write lock is held for the whole sequence.
///
/// NOTE(review): this takes the write side of the same lock that
/// `command_buffer()` hands out read guards for, so a read guard still
/// alive on the calling thread when this is invoked would presumably
/// block forever or panic — confirm callers drop their guards first.
///
/// # Panics
///
/// Panics if the lock is poisoned (the `write()` result is unwrapped).
pub fn wait_until_completed(&self) {
|
||||
// Take exclusive access so the buffer can be swapped below.
let mut old = self.command_buffer.write().unwrap();
|
||||
old.commit();
|
||||
old.wait_until_completed();
|
||||
// The committed buffer is not reused; a fresh one takes its place.
let command_buffer = self.command_queue.new_owned_command_buffer();
|
||||
*old = command_buffer;
|
||||
// self.command_buffer.replace_with(|_| command_buffer)
|
||||
}
|
||||
|
||||
/// Returns a reference to the `Kernels` instance held by this device.
pub fn kernels(&self) -> &Kernels {
|
||||
// `&Arc<Kernels>` coerces to `&Kernels` through `Deref`.
&self.kernels
|
||||
}
|
||||
@ -105,6 +116,8 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
self.device.wait_until_completed();
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
|
||||
@ -138,7 +151,7 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = self.dtype;
|
||||
|
||||
let mut buffer = device.new_buffer(el, self.dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
assert_eq!(dtype, DType::F32);
|
||||
candle_metal_kernels::call_affine(
|
||||
@ -168,8 +181,8 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
return Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -225,7 +238,7 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -237,8 +250,8 @@ impl BackendStorage for MetalStorage {
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
@ -256,7 +269,7 @@ impl BackendStorage for MetalStorage {
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
@ -280,8 +293,8 @@ impl BackendStorage for MetalStorage {
|
||||
);
|
||||
}
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -295,7 +308,7 @@ impl BackendStorage for MetalStorage {
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
@ -329,8 +342,8 @@ impl BackendStorage for MetalStorage {
|
||||
} else {
|
||||
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
@ -350,7 +363,7 @@ impl BackendStorage for MetalStorage {
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
let command_buffer = device.command_buffer();
|
||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
{
|
||||
@ -404,8 +417,8 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
@ -428,7 +441,7 @@ impl BackendStorage for MetalStorage {
|
||||
let el = shape.elem_count();
|
||||
let dtype = t.dtype;
|
||||
let mut buffer = self.device.new_buffer(el, dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -447,8 +460,8 @@ impl BackendStorage for MetalStorage {
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
@ -544,7 +557,7 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||
};
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_index_select(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
@ -558,8 +571,8 @@ impl BackendStorage for MetalStorage {
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -641,7 +654,7 @@ impl BackendStorage for MetalStorage {
|
||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let command_buffer = self.device.command_buffer();
|
||||
for bi in 0..b {
|
||||
// Create matrix objects
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||
@ -689,14 +702,14 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
command_buffer,
|
||||
command_buffer.deref(),
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
@ -712,7 +725,7 @@ impl BackendStorage for MetalStorage {
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
let command_buffer = self.device.command_buffer();
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
@ -733,8 +746,8 @@ impl BackendStorage for MetalStorage {
|
||||
dst_offset,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -760,10 +773,12 @@ impl BackendDevice for MetalDevice {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
Ok(Self {
|
||||
device,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
kernels,
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user