Reusing a single command buffer (for now) to speed things up.

Nicolas Patry
2023-11-11 12:50:25 +01:00
parent a52b71686b
commit e02f1912bb
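
The change below threads one shared Metal command buffer through the whole backend: MetalDevice now holds an Arc<RwLock<metal::CommandBuffer>>, each op encodes into that buffer via device.command_buffer() instead of creating and synchronously committing its own, and the buffer is only committed and waited on when data must come back to the CPU (to_cpu_storage calls wait_until_completed), after which a fresh buffer is swapped in. Below is a minimal sketch of the pattern outside candle, assuming the stock metal crate API; SharedQueue is an illustrative name, and new_command_buffer().to_owned() stands in for the new_owned_command_buffer() helper used in the diff.

    use std::sync::{Arc, RwLock, RwLockReadGuard};
    use metal::{CommandBuffer, CommandQueue, Device};

    /// Illustrative only: one live command buffer shared by every operation.
    struct SharedQueue {
        queue: CommandQueue,
        current: Arc<RwLock<CommandBuffer>>,
    }

    impl SharedQueue {
        fn new(device: &Device) -> Self {
            let queue = device.new_command_queue();
            // `to_owned()` turns the &CommandBufferRef returned by the queue
            // into an owned CommandBuffer that can live inside the lock.
            let current = Arc::new(RwLock::new(queue.new_command_buffer().to_owned()));
            Self { queue, current }
        }

        /// Ops grab the shared buffer, encode their kernels, and return
        /// without committing anything.
        fn command_buffer(&self) -> RwLockReadGuard<'_, CommandBuffer> {
            self.current.read().unwrap()
        }

        /// Called only when results are needed on the CPU: flush everything
        /// encoded so far, block until the GPU finishes, then install a
        /// fresh buffer for the next batch of ops.
        fn wait_until_completed(&self) {
            let mut current = self.current.write().unwrap();
            current.commit();
            current.wait_until_completed();
            *current = self.queue.new_command_buffer().to_owned();
        }
    }

The win is that the per-op commit()/wait_until_completed() pair disappears; the commented-out calls in each method below mark where that synchronization used to happen.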


@ -7,8 +7,9 @@ use candle_metal_kernels::Kernels;
use core::mem;
use half::{bf16, f16};
use metal;
-use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use std::sync::Arc;
+use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use std::ops::Deref;
+use std::sync::{Arc, RwLock};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -36,6 +37,7 @@ impl From<String> for MetalError {
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
+command_buffer: Arc<RwLock<metal::CommandBuffer>>,
kernels: Arc<candle_metal_kernels::Kernels>,
}
@ -54,10 +56,6 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
-// pub fn metal_device(&self) -> &metal::DeviceRef {
-// self.device.as_ref()
-// }
pub fn id(&self) -> NSUInteger {
self.registry_id()
}
@ -66,6 +64,19 @@ impl MetalDevice {
&self.command_queue
}
+pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
+self.command_buffer.read().unwrap()
+}
+pub fn wait_until_completed(&self) {
+let mut old = self.command_buffer.write().unwrap();
+old.commit();
+old.wait_until_completed();
+let command_buffer = self.command_queue.new_owned_command_buffer();
+*old = command_buffer;
+// self.command_buffer.replace_with(|_| command_buffer)
+}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
@ -105,6 +116,8 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
+self.device.wait_until_completed();
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
@ -138,7 +151,7 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let mut buffer = device.new_buffer(el, self.dtype);
-let command_buffer = self.device.command_queue.new_command_buffer();
+let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
assert_eq!(dtype, DType::F32);
candle_metal_kernels::call_affine(
@ -168,8 +181,8 @@ impl BackendStorage for MetalStorage {
)
.unwrap();
}
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
return Ok(Self {
buffer,
device: device.clone(),
@ -225,7 +238,7 @@ impl BackendStorage for MetalStorage {
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let mut buffer = device.new_buffer(dst_el, dtype);
-let command_buffer = self.device.command_queue.new_command_buffer();
+let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
@ -237,8 +250,8 @@ impl BackendStorage for MetalStorage {
&mut buffer,
)
.map_err(MetalError::from)?;
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
@ -256,7 +269,7 @@ impl BackendStorage for MetalStorage {
let shape = layout.shape();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
-let command_buffer = device.command_queue.new_command_buffer();
+let command_buffer = device.command_buffer();
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
@ -280,8 +293,8 @@ impl BackendStorage for MetalStorage {
);
}
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
device: device.clone(),
@ -295,7 +308,7 @@ impl BackendStorage for MetalStorage {
let shape = layout.shape();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
-let command_buffer = device.command_queue.new_command_buffer();
+let command_buffer = device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
@ -329,8 +342,8 @@ impl BackendStorage for MetalStorage {
} else {
todo!("TODO Implement the kernel calling {}", B::KERNEL);
}
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
@ -350,7 +363,7 @@ impl BackendStorage for MetalStorage {
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
-let command_buffer = device.command_queue.new_command_buffer();
+let command_buffer = device.command_buffer();
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
{
@ -404,8 +417,8 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
@ -428,7 +441,7 @@ impl BackendStorage for MetalStorage {
let el = shape.elem_count();
let dtype = t.dtype;
let mut buffer = self.device.new_buffer(el, dtype);
-let command_buffer = self.device.command_queue.new_command_buffer();
+let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
@ -447,8 +460,8 @@ impl BackendStorage for MetalStorage {
&mut buffer,
)
.map_err(MetalError::from)?;
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
device,
@ -544,7 +557,7 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "is_u32_f32",
(left, right) => todo!("index select metal {left:?} {right:?}"),
};
-let command_buffer = self.device.command_queue.new_command_buffer();
+let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
@ -558,8 +571,8 @@ impl BackendStorage for MetalStorage {
&mut buffer,
)
.map_err(MetalError::from)?;
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
device: device.clone(),
@ -641,7 +654,7 @@ impl BackendStorage for MetalStorage {
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
-let command_buffer = self.device.command_queue.new_command_buffer();
+let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
@ -689,14 +702,14 @@ impl BackendStorage for MetalStorage {
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
-command_buffer,
+command_buffer.deref(),
&left_matrix,
&right_matrix,
&result_matrix,
);
}
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
Ok(Self {
buffer: out_buffer,
@ -712,7 +725,7 @@ impl BackendStorage for MetalStorage {
if el_count == 0 {
return Ok(());
}
-let command_buffer = self.device.command_queue.new_command_buffer();
+let command_buffer = self.device.command_buffer();
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
@ -733,8 +746,8 @@ impl BackendStorage for MetalStorage {
dst_offset,
)
.map_err(MetalError::from)?;
-command_buffer.commit();
-command_buffer.wait_until_completed();
+// command_buffer.commit();
+// command_buffer.wait_until_scheduled();
Ok(())
}
}
@ -760,10 +773,12 @@ impl BackendDevice for MetalDevice {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
+let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
let kernels = Arc::new(Kernels::new());
Ok(Self {
device,
command_queue,
+command_buffer,
kernels,
})
}