diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index c48ad5c7..f363a84b 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -7,8 +7,9 @@ use candle_metal_kernels::Kernels; use core::mem; use half::{bf16, f16}; use metal; -use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; -use std::sync::Arc; +use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use std::ops::Deref; +use std::sync::{Arc, RwLock}; /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -36,6 +37,7 @@ impl From for MetalError { pub struct MetalDevice { device: metal::Device, command_queue: metal::CommandQueue, + command_buffer: Arc>, kernels: Arc, } @@ -54,10 +56,6 @@ impl std::ops::Deref for MetalDevice { } impl MetalDevice { - // pub fn metal_device(&self) -> &metal::DeviceRef { - // self.device.as_ref() - // } - pub fn id(&self) -> NSUInteger { self.registry_id() } @@ -66,6 +64,19 @@ impl MetalDevice { &self.command_queue } + pub fn command_buffer(&self) -> std::sync::RwLockReadGuard { + self.command_buffer.read().unwrap() + } + + pub fn wait_until_completed(&self) { + let mut old = self.command_buffer.write().unwrap(); + old.commit(); + old.wait_until_completed(); + let command_buffer = self.command_queue.new_owned_command_buffer(); + *old = command_buffer; + // self.command_buffer.replace_with(|_| command_buffer) + } + pub fn kernels(&self) -> &Kernels { &self.kernels } @@ -105,6 +116,8 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { + self.device.wait_until_completed(); + match self.dtype { DType::U8 => Ok(CpuStorage::U8( self.buffer.read_to_vec(self.buffer.length() as usize / 1), @@ -138,7 +151,7 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let mut buffer = device.new_buffer(el, self.dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { assert_eq!(dtype, DType::F32); candle_metal_kernels::call_affine( @@ -168,8 +181,8 @@ impl BackendStorage for MetalStorage { ) .unwrap(); } - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); return Ok(Self { buffer, device: device.clone(), @@ -225,7 +238,7 @@ impl BackendStorage for MetalStorage { } let dtype = if return_index { DType::U32 } else { self.dtype }; let mut buffer = device.new_buffer(dst_el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_reduce_contiguous( &device.device, &command_buffer, @@ -237,8 +250,8 @@ impl BackendStorage for MetalStorage { &mut buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); Ok(Self { buffer, @@ -256,7 +269,7 @@ impl BackendStorage for MetalStorage { let shape = layout.shape(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let command_buffer = device.command_buffer(); if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", @@ -280,8 +293,8 @@ impl BackendStorage for MetalStorage { ); } - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); Ok(Self { buffer, device: device.clone(), @@ -295,7 +308,7 @@ impl BackendStorage for MetalStorage { let shape = layout.shape(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let command_buffer = device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; @@ -329,8 +342,8 @@ impl BackendStorage for MetalStorage { } else { todo!("TODO Implement the kernel calling {}", B::KERNEL); } - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); Ok(Self { buffer, @@ -350,7 +363,7 @@ impl BackendStorage for MetalStorage { let shape = lhs_l.shape(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_queue.new_command_buffer(); + let command_buffer = device.command_buffer(); if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) { @@ -404,8 +417,8 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); Ok(Self { buffer, @@ -428,7 +441,7 @@ impl BackendStorage for MetalStorage { let el = shape.elem_count(); let dtype = t.dtype; let mut buffer = self.device.new_buffer(el, dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_where_cond_strided( &device.device, &command_buffer, @@ -447,8 +460,8 @@ impl BackendStorage for MetalStorage { &mut buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); Ok(Self { buffer, device, @@ -544,7 +557,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "is_u32_f32", (left, right) => todo!("index select metal {left:?} {right:?}"), }; - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer(); candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -558,8 +571,8 @@ impl BackendStorage for MetalStorage { &mut buffer, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); Ok(Self { buffer, device: device.clone(), @@ -641,7 +654,7 @@ impl BackendStorage for MetalStorage { let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); let out_buffer = self.device.new_buffer(elem_count, self.dtype); - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer(); for bi in 0..b { // Create matrix objects let left_matrix = Matrix::init_with_buffer_descriptor( @@ -689,14 +702,14 @@ impl BackendStorage for MetalStorage { // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( - command_buffer, + command_buffer.deref(), &left_matrix, &right_matrix, &result_matrix, ); } - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); Ok(Self { buffer: out_buffer, @@ -712,7 +725,7 @@ impl BackendStorage for MetalStorage { if el_count == 0 { return Ok(()); } - let command_buffer = self.device.command_queue.new_command_buffer(); + let command_buffer = self.device.command_buffer(); let kernel_name = match self.dtype { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, @@ -733,8 +746,8 @@ impl BackendStorage for MetalStorage { dst_offset, ) .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); Ok(()) } } @@ -760,10 +773,12 @@ impl BackendDevice for MetalDevice { let device = metal::Device::all().swap_remove(ordinal); let command_queue = device.new_command_queue(); + let command_buffer = Arc::new(RwLock::new(command_queue.new_owned_command_buffer())); let kernels = Arc::new(Kernels::new()); Ok(Self { device, command_queue, + command_buffer, kernels, }) }