diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index b24db020..d8518b3e 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -34,12 +34,48 @@ impl From<String> for MetalError {
 #[derive(Clone)]
 pub struct MetalDevice {
+    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
     device: metal::Device,
+
+    /// Single command queue for the entire device.
     command_queue: metal::CommandQueue,
-    command_buffers: Arc<RwLock<Vec<metal::CommandBuffer>>>,
+    /// One command buffer at a time.
+    /// The scheduler works by allowing multiple
+    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
+    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
+    /// prevents overlapping of CPU and GPU commands (because a command buffer needs to be committed
+    /// before it starts to run).
+    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
+    /// by their START time, but there is no guarantee that command buffer 1 will finish before
+    /// command buffer 2 starts (or there are Metal bugs there).
+    command_buffer: Arc<RwLock<metal::CommandBuffer>>,
+    /// Keeps track of the current number of compute command encoders on the current
+    /// command buffer.
+    /// Arc, RwLock because of the interior mutability.
     command_buffer_index: Arc<RwLock<usize>>,
+    /// The maximum number of [compute command encoders](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc).
+    compute_per_buffer: usize,
+    /// Every compute command encoder (and blit encoder) is guarded by this Fence, forcing the
+    /// execution order to be linear.
+    /// It could be relaxed in some circumstances, by managing the dependencies in the
+    /// compute graph ourselves.
     fence: metal::Fence,
+    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
+    /// Heavily used by [`candle_metal_kernels`]; both fences need to match.
     kernels: Arc<Kernels>,
+    /// Simple allocator struct.
+    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
+    /// We store the buffers in [`Arc`] because it's much faster than Obj-C's internal ref counting
+    /// (could be linked to FFI communication overhead).
+    ///
+    /// Whenever a buffer has a strong_count == 1, we can reuse it: it means it was dropped in the
+    /// graph calculation and only we, the allocator, kept a reference to it, therefore it's free
+    /// to be reused. However, in order for this to work, we need to guarantee the order of
+    /// operations, so that this buffer is not being used by another kernel at the same time.
+    /// Arc is the CPU reference count; it doesn't mean anything on the GPU side of things.
+    ///
+    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
+    /// (strong_count == 1).
     buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
 }
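The reuse rule described in the `buffers` comment can be made concrete with a small, self-contained sketch. This is not the candle implementation: `BucketAllocator` and `FakeBuffer` (a plain `Vec<u8>` standing in for `metal::Buffer`) are hypothetical names, and the cleanup sweep is omitted; it only illustrates the `strong_count == 1` check.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Stand-in for `metal::Buffer` so this sketch runs anywhere.
type FakeBuffer = Vec<u8>;
type Size = u64;

/// Hypothetical bucketed allocator mirroring the rule documented above: a buffer whose
/// strong_count is 1 is only referenced by the allocator itself, so it can be handed out again.
struct BucketAllocator {
    buffers: Arc<RwLock<HashMap<Size, Vec<Arc<FakeBuffer>>>>>,
}

impl BucketAllocator {
    fn new() -> Self {
        Self { buffers: Arc::new(RwLock::new(HashMap::new())) }
    }

    fn allocate(&self, size: Size) -> Arc<FakeBuffer> {
        let mut buffers = self.buffers.write().unwrap();
        let bucket = buffers.entry(size).or_default();
        // Reuse any buffer in this size bucket that only the allocator still holds.
        if let Some(buf) = bucket.iter().find(|b| Arc::strong_count(b) == 1) {
            return buf.clone();
        }
        // Otherwise allocate a fresh one and keep a reference in the bucket.
        let buf = Arc::new(vec![0u8; size as usize]);
        bucket.push(buf.clone());
        buf
    }
}

fn main() {
    let alloc = BucketAllocator::new();
    let a = alloc.allocate(1024);
    let first_ptr = Arc::as_ptr(&a);
    drop(a); // strong_count drops back to 1, so the buffer becomes reusable
    let b = alloc.allocate(1024);
    assert_eq!(first_ptr, Arc::as_ptr(&b)); // the same underlying buffer was handed out again
}
```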
@@ -71,13 +107,13 @@ impl MetalDevice {
     }
 
     pub fn command_buffer(&self) -> CommandBuffer {
-        let mut command_buffers = self.command_buffers.try_write().unwrap();
-        let mut command_buffer = command_buffers[0].to_owned();
+        let mut command_buffer_lock = self.command_buffer.try_write().unwrap();
+        let mut command_buffer = command_buffer_lock.to_owned();
         let mut index = self.command_buffer_index.try_write().unwrap();
-        if *index > 20 {
+        if *index > self.compute_per_buffer {
             command_buffer.commit();
             command_buffer = self.command_queue.new_command_buffer().to_owned();
-            *command_buffers = vec![command_buffer.clone()];
+            *command_buffer_lock = command_buffer.clone();
             *index = 0;
         }
         *index += 1;
@@ -85,8 +121,7 @@ impl MetalDevice {
     }
 
     pub fn wait_until_completed(&self) {
-        let mut command_buffers = self.command_buffers.try_write().unwrap();
-        let command_buffer = &command_buffers[0];
+        let mut command_buffer = self.command_buffer.try_write().unwrap();
         match command_buffer.status() {
             metal::MTLCommandBufferStatus::Committed
             | metal::MTLCommandBufferStatus::Scheduled
@@ -97,7 +132,7 @@
         }
         command_buffer.commit();
         command_buffer.wait_until_completed();
-        *command_buffers = vec![self.command_queue.new_command_buffer().to_owned()];
+        *command_buffer = self.command_queue.new_command_buffer().to_owned();
     }
 
     pub fn kernels(&self) -> &Kernels {
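The policy in `command_buffer()` above (hand out the same command buffer until `compute_per_buffer` compute command encoders have been scheduled on it, then commit it and start a fresh one; the threshold is later read from the `CANDLE_METAL_COMPUTE_PER_BUFFER` environment variable, defaulting to 20) can be restated as a toy, GPU-free sketch. `Scheduler` and `MockCommandBuffer` are illustrative stand-ins, not candle types.

```rust
use std::sync::{Arc, RwLock};

/// Stand-in for `metal::CommandBuffer` so the policy is runnable without a GPU.
#[derive(Clone, Debug)]
struct MockCommandBuffer {
    id: usize,
}

/// Hypothetical restatement of the scheduling policy described above.
struct Scheduler {
    current: Arc<RwLock<MockCommandBuffer>>,
    index: Arc<RwLock<usize>>,
    compute_per_buffer: usize,
    next_id: Arc<RwLock<usize>>,
}

impl Scheduler {
    fn new(compute_per_buffer: usize) -> Self {
        Self {
            current: Arc::new(RwLock::new(MockCommandBuffer { id: 0 })),
            index: Arc::new(RwLock::new(0)),
            compute_per_buffer,
            next_id: Arc::new(RwLock::new(1)),
        }
    }

    /// Mirrors the logic of `MetalDevice::command_buffer` above, with `commit()` replaced by a print.
    fn command_buffer(&self) -> MockCommandBuffer {
        let mut current = self.current.write().unwrap();
        let mut index = self.index.write().unwrap();
        if *index > self.compute_per_buffer {
            println!("committing command buffer {}", current.id);
            let mut next_id = self.next_id.write().unwrap();
            *current = MockCommandBuffer { id: *next_id };
            *next_id += 1;
            *index = 0;
        }
        *index += 1;
        current.clone()
    }
}

fn main() {
    let scheduler = Scheduler::new(3);
    for _ in 0..10 {
        let cb = scheduler.command_buffer();
        println!("encoding one compute pass on command buffer {}", cb.id);
    }
}
```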
@@ -108,12 +143,65 @@
         &self.device
     }
 
+    /// Creates a new buffer (not necessarily zeroed).
+    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
+    /// This means the buffer data cannot be read on the CPU directly.
+    ///
+    /// [`name`] is only used to keep track of the resource origin in case of bugs.
     pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self._new_buffer(size, MTLResourceOptions::StorageModePrivate, name)
+        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
 
-    fn _new_buffer(
+    /// Creates a new buffer (not necessarily zeroed).
+    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
+    /// This means the buffer can be read on the CPU but will require manual
+    /// synchronization when the CPU memory is modified.
+    /// Used as a bridge to gather data back from the GPU.
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
+    }
+
+    /// Creates a new buffer from data.
+    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
+    ///
+    /// This method will block the computation because of the
+    /// lack of lifetime management through the GPU.
+    /// See the internal comment for technical details.
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+        let size = core::mem::size_of_val(data) as NSUInteger;
+        let tmp = self.device.new_buffer_with_data(
+            data.as_ptr() as *const core::ffi::c_void,
+            size,
+            metal::MTLResourceOptions::StorageModeManaged,
+        );
+        let real = self.allocate_buffer(
+            size,
+            metal::MTLResourceOptions::StorageModePrivate,
+            "with_data",
+        );
+        let command_buffer = self.command_buffer();
+        command_buffer.set_label("with_data");
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.wait_for_fence(&self.fence);
+        blit.set_label("with_data_blit");
+        blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+        blit.update_fence(&self.fence);
+        blit.end_encoding();
+
+        // This is necessary, for mmaped safetensors,
+        // because of the unsafe slice cast we're doing.
+        // The slice might not live long enough for Metal
+        // to actually fill the GPU buffer.
+        // Putting this wait forces the GPU buffer to be filled
+        // with the actual data, allowing the CPU storage to
+        // deallocate properly.
+        self.wait_until_completed();
+        real
+    }
+
+    /// The critical allocator algorithm
+    fn allocate_buffer(
         &self,
         size: NSUInteger,
         option: MTLResourceOptions,
@@ -142,42 +230,7 @@ impl MetalDevice {
         new_buffer
     }
 
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
-        self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
-    }
-
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
-        let size = core::mem::size_of_val(data) as NSUInteger;
-        let tmp = self.device.new_buffer_with_data(
-            data.as_ptr() as *const core::ffi::c_void,
-            size,
-            metal::MTLResourceOptions::StorageModeManaged,
-        );
-        let real = self._new_buffer(
-            size,
-            metal::MTLResourceOptions::StorageModePrivate,
-            "with_data",
-        );
-        let command_buffer = self.command_buffer();
-        command_buffer.set_label("with_data");
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.wait_for_fence(&self.fence);
-        blit.set_label("with_data_blit");
-        blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-        blit.update_fence(&self.fence);
-        blit.end_encoding();
-
-        // This is necessary, for mmaped safetensors
-        // Because of the unsafe slice cast we're doing.
-        // The slice might not live long enough for metal
-        // To actually fill the GPU buffer.
-        // Putting this wait forces the GPU buffer to be filled
-        // with the actual data allowing the CPU storage todo
-        // deallocate properly.
-        self.wait_until_completed();
-        real
-    }
-
+    /// Create a metal GPU capture trace on [`path`].
     pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
         let capture = metal::CaptureManager::shared();
         let descriptor = metal::CaptureDescriptor::new();
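For reference, the staging pattern that `new_buffer_with_data` follows (copy CPU data into a Managed buffer, blit it into a Private buffer, then wait so the CPU slice may be dropped) looks roughly like this as a standalone program using the `metal` crate. This is a sketch, not candle code: it requires macOS, skips the fence and label bookkeeping, and does no error handling.

```rust
use metal::{Device, MTLResourceOptions, NSUInteger};

fn main() {
    let device = Device::system_default().expect("no Metal device");
    let queue = device.new_command_queue();

    let data: Vec<f32> = (0..1024).map(|v| v as f32).collect();
    let size = core::mem::size_of_val(data.as_slice()) as NSUInteger;

    // CPU-visible staging buffer initialized with the data.
    let staging = device.new_buffer_with_data(
        data.as_ptr() as *const core::ffi::c_void,
        size,
        MTLResourceOptions::StorageModeManaged,
    );
    // GPU-only destination buffer.
    let gpu_buffer = device.new_buffer(size, MTLResourceOptions::StorageModePrivate);

    // Blit the staging buffer into the private one.
    let command_buffer = queue.new_command_buffer();
    let blit = command_buffer.new_blit_command_encoder();
    blit.copy_from_buffer(&staging, 0, &gpu_buffer, 0, size);
    blit.end_encoding();
    command_buffer.commit();
    // Wait so `data` (or an mmapped source in the real code) can be safely dropped afterwards.
    command_buffer.wait_until_completed();
}
```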
@@ -194,8 +247,11 @@ impl MetalDevice {
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
+    /// The actual buffer containing the data.
     buffer: Arc<Buffer>,
+    /// A reference to the device owning this buffer.
     device: MetalDevice,
+    /// The dtype is kept since buffers are untyped.
     dtype: DType,
 }
 
@@ -952,29 +1008,25 @@ impl BackendDevice for MetalDevice {
     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
-
-        let n = 1;
         let command_queue = device.new_command_queue();
-
-        let command_buffers = (0..n)
-            .map(|i| {
-                let command_buffer = command_queue.new_command_buffer().to_owned();
-                command_buffer.enqueue();
-                command_buffer.set_label(&format!("num {i}"));
-                command_buffer
-            })
-            .collect();
-        let command_buffers = Arc::new(RwLock::new(command_buffers));
+        let command_buffer = command_queue.new_command_buffer().to_owned();
+        command_buffer.enqueue();
+        let command_buffer = Arc::new(RwLock::new(command_buffer));
         let command_buffer_index = Arc::new(RwLock::new(0));
         let fence = device.new_fence();
         let kernels = Arc::new(Kernels::new(fence.clone()));
         let buffers = Arc::new(RwLock::new(HashMap::new()));
+        let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
+            Ok(val) => val.parse()?,
+            _ => 20,
+        };
         Ok(Self {
             device,
             fence,
             command_queue,
-            command_buffers,
+            command_buffer,
             command_buffer_index,
+            compute_per_buffer,
             buffers,
             kernels,
         })
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 60f9b8a6..2fa571bc 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -15,6 +15,10 @@ const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
 
+/// Most kernels apply similarly across the tensors.
+/// This creates a strategy that uses the maximum number of threads per threadgroup (capped at the
+/// actual total buffer length).
+/// Then kernels can just do their op on their single point in the buffer.
 fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
     let size = length as u64;
     let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
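The split described in the comment above boils down to a ceiling division: one thread per element, with the threadgroup width capped at both the pipeline maximum and the element count. Here is a pure-integer sketch (`linear_split_sketch` is a made-up name; the real function takes the pipeline and returns `MTLSize` values):

```rust
/// Sketch of the split: returns (number of threadgroups, threads per threadgroup).
/// Assumes `length > 0`; plain u64s are used so this runs without the `metal` crate.
fn linear_split_sketch(max_threads_per_threadgroup: u64, length: u64) -> (u64, u64) {
    let width = std::cmp::min(max_threads_per_threadgroup, length);
    // Ceiling division: enough threadgroups to cover every element.
    let count = (length + width - 1) / width;
    (count, width)
}

fn main() {
    // e.g. a 1_000_000-element kernel on a pipeline allowing 1024 threads per threadgroup
    let (groups, width) = linear_split_sketch(1024, 1_000_000);
    assert_eq!((groups, width), (977, 1024));
    println!("{groups} threadgroups of {width} threads");
}
```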

@@ -36,6 +40,10 @@
 fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
     <P as EncoderParam>::set_param(encoder, position, data)
 }
+
+/// Helper trait to set the various parameters on the compute command encoder
+/// on a single line.
+/// Prevents getting some argument counts wrong and mixing up length and size in bytes.
 trait EncoderParam {
     fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
 }
@@ -220,6 +228,9 @@ impl Kernels {
             Source::Mfa => panic!("Invalid lib"),
         }
     }
+
+    /// Load the given library from its [`source`].
+    /// If it has been previously loaded it will just be fetched from the cache.
     pub fn load_library(
         &self,
         device: &Device,
@@ -262,6 +273,9 @@
         Ok(func)
     }
 
+    /// Load the given pipeline:
+    /// loads the library from source, then gets the function [`name`] from
+    /// that source.
     fn load_pipeline_with_constants(
         &self,
         device: &Device,
@@ -290,6 +304,9 @@
         }
     }
 
+    /// Load the given pipeline:
+    /// loads the library from source, then gets the function [`name`] from
+    /// that source (without constants).
     pub fn load_pipeline(
         &self,
         device: &Device,
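The `EncoderParam` idea documented above can be illustrated without a GPU: each parameter type knows how to bind itself, so call sites cannot confuse element counts with byte sizes. `FakeEncoder` and the impls below are hypothetical stand-ins, not the candle-metal-kernels implementations.

```rust
// Fake encoder that only records what would be bound, so the sketch runs anywhere.
struct FakeEncoder;

impl FakeEncoder {
    fn set_bytes(&self, position: u64, len_bytes: u64) {
        println!("bind at slot {position}: {len_bytes} bytes");
    }
}

/// Same shape as the trait above: each type decides how it is passed to the encoder.
trait EncoderParam {
    fn set_param(encoder: &FakeEncoder, position: u64, data: Self);
}

impl EncoderParam for usize {
    fn set_param(encoder: &FakeEncoder, position: u64, _data: Self) {
        // Scalar: the byte size comes from the type, not from the caller.
        encoder.set_bytes(position, core::mem::size_of::<usize>() as u64);
    }
}

impl<'a> EncoderParam for &'a [f32] {
    fn set_param(encoder: &FakeEncoder, position: u64, data: Self) {
        // Slice: the byte size comes from the data itself.
        encoder.set_bytes(position, core::mem::size_of_val(data) as u64);
    }
}

fn set_param<P: EncoderParam>(encoder: &FakeEncoder, position: u64, data: P) {
    <P as EncoderParam>::set_param(encoder, position, data)
}

fn main() {
    let encoder = FakeEncoder;
    set_param(&encoder, 0, 42usize);          // scalar parameter
    set_param(&encoder, 1, &[1.0f32; 8][..]); // buffer-like parameter
}
```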