Reusing a single buffer (for now) to speed things up.

This commit is contained in:
Nicolas Patry
2023-11-11 12:50:25 +01:00
parent a52b71686b
commit e02f1912bb

View File

@ -7,8 +7,9 @@ use candle_metal_kernels::Kernels;
use core::mem; use core::mem;
use half::{bf16, f16}; use half::{bf16, f16};
use metal; use metal;
use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::Arc; use std::ops::Deref;
use std::sync::{Arc, RwLock};
/// Metal related errors /// Metal related errors
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -36,6 +37,7 @@ impl From<String> for MetalError {
pub struct MetalDevice { pub struct MetalDevice {
device: metal::Device, device: metal::Device,
command_queue: metal::CommandQueue, command_queue: metal::CommandQueue,
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
kernels: Arc<candle_metal_kernels::Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
} }
@ -54,10 +56,6 @@ impl std::ops::Deref for MetalDevice {
} }
impl MetalDevice { impl MetalDevice {
// pub fn metal_device(&self) -> &metal::DeviceRef {
// self.device.as_ref()
// }
pub fn id(&self) -> NSUInteger { pub fn id(&self) -> NSUInteger {
self.registry_id() self.registry_id()
} }
@ -66,6 +64,19 @@ impl MetalDevice {
&self.command_queue &self.command_queue
} }
pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
self.command_buffer.read().unwrap()
}
pub fn wait_until_completed(&self) {
let mut old = self.command_buffer.write().unwrap();
old.commit();
old.wait_until_completed();
let command_buffer = self.command_queue.new_owned_command_buffer();
*old = command_buffer;
// self.command_buffer.replace_with(|_| command_buffer)
}
pub fn kernels(&self) -> &Kernels { pub fn kernels(&self) -> &Kernels {
&self.kernels &self.kernels
} }
@ -105,6 +116,8 @@ impl BackendStorage for MetalStorage {
} }
fn to_cpu_storage(&self) -> Result<CpuStorage> { fn to_cpu_storage(&self) -> Result<CpuStorage> {
self.device.wait_until_completed();
match self.dtype { match self.dtype {
DType::U8 => Ok(CpuStorage::U8( DType::U8 => Ok(CpuStorage::U8(
self.buffer.read_to_vec(self.buffer.length() as usize / 1), self.buffer.read_to_vec(self.buffer.length() as usize / 1),
@ -138,7 +151,7 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype; let dtype = self.dtype;
let mut buffer = device.new_buffer(el, self.dtype); let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_queue.new_command_buffer(); let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 { if layout.is_contiguous() && layout.start_offset() == 0 {
assert_eq!(dtype, DType::F32); assert_eq!(dtype, DType::F32);
candle_metal_kernels::call_affine( candle_metal_kernels::call_affine(
@ -168,8 +181,8 @@ impl BackendStorage for MetalStorage {
) )
.unwrap(); .unwrap();
} }
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
return Ok(Self { return Ok(Self {
buffer, buffer,
device: device.clone(), device: device.clone(),
@ -225,7 +238,7 @@ impl BackendStorage for MetalStorage {
} }
let dtype = if return_index { DType::U32 } else { self.dtype }; let dtype = if return_index { DType::U32 } else { self.dtype };
let mut buffer = device.new_buffer(dst_el, dtype); let mut buffer = device.new_buffer(dst_el, dtype);
let command_buffer = self.device.command_queue.new_command_buffer(); let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_reduce_contiguous( candle_metal_kernels::call_reduce_contiguous(
&device.device, &device.device,
&command_buffer, &command_buffer,
@ -237,8 +250,8 @@ impl BackendStorage for MetalStorage {
&mut buffer, &mut buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
Ok(Self { Ok(Self {
buffer, buffer,
@ -256,7 +269,7 @@ impl BackendStorage for MetalStorage {
let shape = layout.shape(); let shape = layout.shape();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer(); let command_buffer = device.command_buffer();
if layout.is_contiguous() { if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) { let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::F32) => "cast_u32_f32",
@ -280,8 +293,8 @@ impl BackendStorage for MetalStorage {
); );
} }
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
Ok(Self { Ok(Self {
buffer, buffer,
device: device.clone(), device: device.clone(),
@ -295,7 +308,7 @@ impl BackendStorage for MetalStorage {
let shape = layout.shape(); let shape = layout.shape();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer(); let command_buffer = device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 { if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous; use candle_metal_kernels::unary::contiguous;
@ -329,8 +342,8 @@ impl BackendStorage for MetalStorage {
} else { } else {
todo!("TODO Implement the kernel calling {}", B::KERNEL); todo!("TODO Implement the kernel calling {}", B::KERNEL);
} }
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
Ok(Self { Ok(Self {
buffer, buffer,
@ -350,7 +363,7 @@ impl BackendStorage for MetalStorage {
let shape = lhs_l.shape(); let shape = lhs_l.shape();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer(); let command_buffer = device.command_buffer();
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
{ {
@ -404,8 +417,8 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
Ok(Self { Ok(Self {
buffer, buffer,
@ -428,7 +441,7 @@ impl BackendStorage for MetalStorage {
let el = shape.elem_count(); let el = shape.elem_count();
let dtype = t.dtype; let dtype = t.dtype;
let mut buffer = self.device.new_buffer(el, dtype); let mut buffer = self.device.new_buffer(el, dtype);
let command_buffer = self.device.command_queue.new_command_buffer(); let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_where_cond_strided( candle_metal_kernels::call_where_cond_strided(
&device.device, &device.device,
&command_buffer, &command_buffer,
@ -447,8 +460,8 @@ impl BackendStorage for MetalStorage {
&mut buffer, &mut buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
Ok(Self { Ok(Self {
buffer, buffer,
device, device,
@ -544,7 +557,7 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F32) => "is_u32_f32",
(left, right) => todo!("index select metal {left:?} {right:?}"), (left, right) => todo!("index select metal {left:?} {right:?}"),
}; };
let command_buffer = self.device.command_queue.new_command_buffer(); let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_index_select( candle_metal_kernels::call_index_select(
&device.device, &device.device,
&command_buffer, &command_buffer,
@ -558,8 +571,8 @@ impl BackendStorage for MetalStorage {
&mut buffer, &mut buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
Ok(Self { Ok(Self {
buffer, buffer,
device: device.clone(), device: device.clone(),
@ -641,7 +654,7 @@ impl BackendStorage for MetalStorage {
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype); let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_queue.new_command_buffer(); let command_buffer = self.device.command_buffer();
for bi in 0..b { for bi in 0..b {
// Create matrix objects // Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor( let left_matrix = Matrix::init_with_buffer_descriptor(
@ -689,14 +702,14 @@ impl BackendStorage for MetalStorage {
// Encode kernel to command buffer // Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer( matrix_multiplication.encode_to_command_buffer(
command_buffer, command_buffer.deref(),
&left_matrix, &left_matrix,
&right_matrix, &right_matrix,
&result_matrix, &result_matrix,
); );
} }
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
Ok(Self { Ok(Self {
buffer: out_buffer, buffer: out_buffer,
@ -712,7 +725,7 @@ impl BackendStorage for MetalStorage {
if el_count == 0 { if el_count == 0 {
return Ok(()); return Ok(());
} }
let command_buffer = self.device.command_queue.new_command_buffer(); let command_buffer = self.device.command_buffer();
let kernel_name = match self.dtype { let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
@ -733,8 +746,8 @@ impl BackendStorage for MetalStorage {
dst_offset, dst_offset,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
command_buffer.commit(); // command_buffer.commit();
command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled();
Ok(()) Ok(())
} }
} }
@ -760,10 +773,12 @@ impl BackendDevice for MetalDevice {
let device = metal::Device::all().swap_remove(ordinal); let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer()));
let kernels = Arc::new(Kernels::new()); let kernels = Arc::new(Kernels::new());
Ok(Self { Ok(Self {
device, device,
command_queue, command_queue,
command_buffer,
kernels, kernels,
}) })
} }