diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 7a979086..6e1ecc5e 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -9,7 +9,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, TryLockError};
+use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
 
 /// Simple way to catch lock error without
 /// depending on T
@@ -60,7 +60,8 @@ impl<T> From<TryLockError<T>> for MetalError {
     }
 }
 
-type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
+type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
+type AllocatedBuffers = Arc<RwLock<BufferMap>>;
 
 #[derive(Clone)]
 pub struct MetalDevice {
@@ -68,7 +69,7 @@ pub struct MetalDevice {
     device: metal::Device,
 
     /// Single command queue for the entire device.
-    command_queue: metal::CommandQueue,
+    command_queue: CommandQueue,
     /// One command buffer at a time.
     /// The scheduler works by allowing multiple
     /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
@@ -78,7 +79,7 @@ pub struct MetalDevice {
     /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
     /// for their START time, but there's no guarantee that command buffer1 will finish before
     /// command buffer2 starts (or there are metal bugs there)
-    command_buffer: Arc<RwLock<metal::CommandBuffer>>,
+    command_buffer: Arc<RwLock<CommandBuffer>>,
     /// Keeps track of the current amount of compute command encoders on the current
     /// command buffer
     /// Arc, RwLock because of the interior mutability.
@@ -87,7 +88,7 @@ pub struct MetalDevice {
     compute_per_buffer: usize,
     /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
     /// Heavily used by [`candle_metal_kernels`]
-    kernels: Arc<candle_metal_kernels::Kernels>,
+    kernels: Arc<Kernels>,
     /// Simple allocator struct.
     /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
     /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@@ -99,7 +100,7 @@ pub struct MetalDevice {
     /// operation, so that this buffer is not being used by another kernel at the same time.
     /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
     ///
-    /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
+    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
     /// (strong_count = 1).
     buffers: AllocatedBuffers,
     /// Seed for random number generation.
@@ -145,6 +146,8 @@ impl MetalDevice {
             command_buffer = self.command_queue.new_command_buffer().to_owned();
             *command_buffer_lock = command_buffer.clone();
             *index = 0;
+
+            self.drop_unused_buffers()?;
         }
         *index += 1;
         Ok(command_buffer)
@@ -163,6 +166,7 @@ impl MetalDevice {
         command_buffer.commit();
         command_buffer.wait_until_completed();
         *command_buffer = self.command_queue.new_command_buffer().to_owned();
+
         Ok(())
     }
 
@@ -199,39 +203,25 @@ impl MetalDevice {
     }
 
     /// Creates a new buffer from data.
-    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
+    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
     ///
-    /// This method will block the computation because of the
-    /// lack of lifetime management through the GPU.
-    /// Internal comment for technical details.
+    /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
+    /// allocates the buffer and copies over the existing data before returning the MTLBuffer.
     pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
         let size = core::mem::size_of_val(data) as NSUInteger;
-        let tmp = self.device.new_buffer_with_data(
-            data.as_ptr() as *const core::ffi::c_void,
+        let new_buffer = self.device.new_buffer_with_data(
+            data.as_ptr() as *const c_void,
             size,
-            metal::MTLResourceOptions::StorageModeManaged,
+            MTLResourceOptions::StorageModeManaged,
         );
-        let real = self.allocate_buffer(
-            size,
-            metal::MTLResourceOptions::StorageModePrivate,
-            "with_data",
-        )?;
-        let command_buffer = self.command_buffer()?;
-        command_buffer.set_label("with_data");
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.set_label("with_data_blit");
-        blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-        blit.end_encoding();
+        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+        let subbuffers = buffers
+            .entry((size, MTLResourceOptions::StorageModeManaged))
+            .or_insert(vec![]);
 
-        // This is necessary, for mmaped safetensors
-        // Because of the unsafe slice cast we're doing.
-        // The slice might not live long enough for metal
-        // To actually fill the GPU buffer.
-        // Putting this wait forces the GPU buffer to be filled
-        // with the actual data allowing the CPU storage to do
-        // deallocate properly.
-        self.wait_until_completed()?;
-        Ok(real)
+        let new_buffer = Arc::new(new_buffer);
+        subbuffers.push(new_buffer.clone());
+        Ok(new_buffer)
     }
 
     pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
@@ -255,6 +245,40 @@ impl MetalDevice {
         Ok(buffer)
     }
 
+    fn find_available_buffer(
+        &self,
+        size: NSUInteger,
+        option: MTLResourceOptions,
+        buffers: &RwLockWriteGuard<BufferMap>,
+    ) -> Option<Arc<Buffer>> {
+        let mut best_buffer: Option<&Arc<Buffer>> = None;
+        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
+        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
+            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
+                for sub in subbuffers {
+                    if Arc::strong_count(sub) == 1 {
+                        best_buffer = Some(sub);
+                        best_buffer_size = *buffer_size;
+                    }
+                }
+            }
+        }
+        return best_buffer.map(|b| b.clone());
+    }
+
+    fn drop_unused_buffers(&self) -> Result<()> {
+        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+        for subbuffers in buffers.values_mut() {
+            let newbuffers = subbuffers
+                .iter()
+                .filter(|s| Arc::strong_count(*s) > 1)
+                .map(Arc::clone)
+                .collect();
+            *subbuffers = newbuffers;
+        }
+        Ok(())
+    }
+
     /// The critical allocator algorithm
     fn allocate_buffer(
         &self,
@@ -263,24 +287,18 @@ impl MetalDevice {
         _name: &str,
     ) -> Result<Arc<Buffer>> {
         let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
+            // Cloning also ensures we increment the strong count
+            return Ok(b.clone());
+        }
+
+        let size = buf_size(size);
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
 
-        for sub in &mut *subbuffers {
-            if Arc::strong_count(sub) == 1 {
-                return Ok(sub.clone());
-            }
-        }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
         let new_buffer = Arc::new(new_buffer);
         subbuffers.push(new_buffer.clone());
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
-        }
+
         Ok(new_buffer)
     }
 
@@ -305,6 +323,8 @@ pub struct MetalStorage {
     buffer: Arc<metal::Buffer>,
     /// a reference to the device owning this buffer
     device: MetalDevice,
+    /// The count of allocated elements in the buffer
+    count: usize,
     /// The dtype is kept since buffers are untyped.
     dtype: DType,
 }
@@ -386,7 +406,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         }
-        Ok(Self::new(buffer, device.clone(), dtype))
+        Ok(Self::new(buffer, device.clone(), el, dtype))
     }
 
     fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
@@ -435,7 +455,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         }
-        Ok(Self::new(buffer, device.clone(), dtype))
+        Ok(Self::new(buffer, device.clone(), el, dtype))
     }
 
     fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
@@ -484,7 +504,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         }
-        Ok(Self::new(buffer, device.clone(), dtype))
+        Ok(Self::new(buffer, device.clone(), el, dtype))
     }
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
@@ -562,7 +582,7 @@ impl BackendStorage for MetalStorage {
         )
         .map_err(MetalError::from)?;
 
-        Ok(Self::new(buffer, device, dtype))
+        Ok(Self::new(buffer, device, dst_el, dtype))
     }
 
     fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
@@ -654,7 +674,7 @@ impl BackendStorage for MetalStorage {
             .map_err(MetalError::from)?;
         }
         command_buffer.set_label("to_dtype");
-        Ok(Self::new(buffer, device.clone(), dtype))
+        Ok(Self::new(buffer, device.clone(), el_count, dtype))
     }
 
     fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@@ -774,7 +794,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         }
-        Ok(Self::new(buffer, device.clone(), dtype))
+        Ok(Self::new(buffer, device.clone(), el_count, dtype))
     }
 
     fn binary_impl<B: BinaryOpT>(
@@ -835,7 +855,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, device, dtype))
+        Ok(Self::new(buffer, device, el, dtype))
     }
 
     fn conv1d(
@@ -880,6 +900,7 @@ impl BackendStorage for MetalStorage {
         let col = Self {
             buffer: dst,
             device,
+            count: dst_el,
             dtype: self.dtype,
         };
         let l_out = params.l_out();
@@ -964,6 +985,7 @@ impl BackendStorage for MetalStorage {
         let col = Self {
             buffer: dst,
             device,
+            count: dst_el,
             dtype: self.dtype,
         };
         let h_out = params.out_h();
@@ -1049,7 +1071,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, self.device.clone(), self.dtype))
+        Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
     }
 
     fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
@@ -1083,7 +1105,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, device.clone(), dtype))
+        Ok(Self::new(buffer, device.clone(), dst_el, dtype))
     }
 
     fn scatter_add(
@@ -1172,7 +1194,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, device.clone(), dtype))
+        Ok(Self::new(buffer, device.clone(), dst_el, dtype))
    }
 
     fn index_add(
@@ -1254,7 +1276,12 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, self.device.clone(), self.dtype()))
+        Ok(Self::new(
+            buffer,
+            self.device.clone(),
+            b * m * n,
+            self.dtype(),
+        ))
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
@@ -1303,10 +1330,11 @@ impl BackendStorage for MetalStorage {
 }
 
 impl MetalStorage {
-    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
+    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, count: usize, dtype: DType) -> Self {
         Self {
             buffer,
             device,
+            count,
             dtype,
         }
     }
@@ -1521,29 +1549,23 @@ impl MetalStorage {
             (buffer, dtype)
         };
         command_buffer.set_label("binary");
-        Ok(Self::new(buffer, device.clone(), dtype))
+        Ok(Self::new(buffer, device.clone(), el_count, dtype))
     }
 
     pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
-        let length = self.buffer.length() as usize;
-        let size = self.dtype.size_in_bytes();
-        if length % size != 0 {
-            crate::bail!(
-                "The Metal buffer length is not aligned with dtype {:?}",
-                self.dtype
-            );
-        }
-        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+        let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
+
+        let buffer = self.device.new_buffer_managed(size)?;
         {
             let command_buffer = self.device.command_buffer()?;
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
-            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
             blit.end_encoding();
         }
         self.device.wait_until_completed()?;
-        Ok(read_to_vec(&buffer, length / size))
+        Ok(read_to_vec(&buffer, self.count))
     }
 }
 
@@ -1561,7 +1583,7 @@ impl BackendDevice for MetalDevice {
         let buffers = Arc::new(RwLock::new(HashMap::new()));
         let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
             Ok(val) => val.parse()?,
-            _ => 10,
+            _ => 50,
         };
         let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
             [299792458].as_ptr() as *const c_void,
@@ -1593,7 +1615,12 @@ impl BackendDevice for MetalDevice {
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
         let size = shape.elem_count() * dtype.size_in_bytes();
         let buffer = self.allocate_zeros(size)?;
-        Ok(MetalStorage::new(buffer, self.clone(), dtype))
+        Ok(MetalStorage::new(
+            buffer,
+            self.clone(),
+            shape.elem_count(),
+            dtype,
+        ))
     }
 
     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
@@ -1603,16 +1630,21 @@ impl BackendDevice for MetalDevice {
     }
 
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
-        let buffer = match storage {
-            CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
-            CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
-            CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
-            CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
-            CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
-            CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
-            CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
-        }?;
-        Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
+        let (count, buffer) = match storage {
+            CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+        };
+        Ok(Self::Storage::new(
+            buffer?,
+            self.clone(),
+            count,
+            storage.dtype(),
+        ))
     }
 
     fn rand_uniform(
@@ -1643,7 +1675,12 @@ impl BackendDevice for MetalDevice {
         )
         .map_err(MetalError::from)?;
 
-        Ok(Self::Storage::new(buffer, self.clone(), dtype))
+        Ok(Self::Storage::new(
+            buffer,
+            self.clone(),
+            shape.elem_count(),
+            dtype,
+        ))
     }
 
     fn rand_normal(
@@ -1674,7 +1711,12 @@ impl BackendDevice for MetalDevice {
         )
         .map_err(MetalError::from)?;
 
-        Ok(Self::Storage::new(buffer, self.clone(), dtype))
+        Ok(Self::Storage::new(
+            buffer,
+            self.clone(),
+            shape.elem_count(),
+            dtype,
+        ))
     }
 
     fn set_seed(&self, seed: u64) -> Result<()> {
@@ -1693,6 +1735,10 @@ impl BackendDevice for MetalDevice {
     }
 }
 
+fn buf_size(size: NSUInteger) -> NSUInteger {
+    (size - 1).next_power_of_two() as NSUInteger
+}
+
 fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
     let ptr = buffer.contents() as *const T;
     assert!(!ptr.is_null());
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index af1cf369..7be0f74e 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -106,7 +106,12 @@ impl QMetalStorage {
         }
 
         let buffer = self.device.new_buffer_with_data(&out)?;
-        Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
+        Ok(MetalStorage::new(
+            buffer,
+            self.device.clone(),
+            elem_count,
+            DType::F32,
+        ))
     }
 
     pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
@@ -170,7 +175,7 @@ impl QMetalStorage {
             &dst,
         )
         .map_err(MetalError::from)?;
-        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
+        let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
         Ok((dst_storage, dst_shape))
     }
 }
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index aaec8b56..fdd67142 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -238,7 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
             &output,
         )
         .unwrap();
-        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+        let newstorage =
+            candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
         Ok((newstorage, layout.shape().clone()))
     }
 }
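
The sketch below is not part of the patch. It is a minimal, platform-independent model of the allocation strategy the diff introduces: power-of-two size buckets (buf_size), reuse of buffers whose Arc strong_count is 1 (find_available_buffer / allocate_buffer), and the sweep of unused buffers when a new command buffer starts (drop_unused_buffers). It uses a plain Vec<u8> as a stand-in for metal::Buffer and ignores MTLResourceOptions; the Allocator and FakeBuffer names are illustrative only.

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Stand-in for metal::Buffer: only its bucket size matters here.
type FakeBuffer = Vec<u8>;
type Buckets = HashMap<usize, Vec<Arc<FakeBuffer>>>;

// Mirrors buf_size: round the requested size up to the next power of two so
// that similarly sized tensors land in the same bucket.
fn buf_size(size: usize) -> usize {
    (size.max(1) - 1).next_power_of_two()
}

struct Allocator {
    buffers: Arc<RwLock<Buckets>>,
}

impl Allocator {
    // Reuse the smallest free buffer that is large enough (strong_count == 1
    // means only the allocator still holds it); otherwise allocate a fresh
    // buffer in the rounded-up bucket.
    fn allocate(&self, size: usize) -> Arc<FakeBuffer> {
        let mut buffers = self.buffers.write().unwrap();
        let mut best: Option<Arc<FakeBuffer>> = None;
        let mut best_bucket = usize::MAX;
        for (bucket, subs) in buffers.iter() {
            if *bucket >= size && *bucket < best_bucket {
                if let Some(free) = subs.iter().find(|b| Arc::strong_count(b) == 1) {
                    best = Some(free.clone());
                    best_bucket = *bucket;
                }
            }
        }
        if let Some(b) = best {
            return b;
        }
        let bucket = buf_size(size);
        let new_buffer = Arc::new(vec![0u8; bucket]);
        buffers.entry(bucket).or_default().push(new_buffer.clone());
        new_buffer
    }

    // Drop buffers nobody outside the allocator holds anymore, the same sweep
    // drop_unused_buffers performs when a fresh command buffer is created.
    fn drop_unused(&self) {
        let mut buffers = self.buffers.write().unwrap();
        for subs in buffers.values_mut() {
            subs.retain(|b| Arc::strong_count(b) > 1);
        }
    }
}

fn main() {
    let alloc = Allocator {
        buffers: Arc::new(RwLock::new(HashMap::new())),
    };
    let a = alloc.allocate(1000); // 1000 rounds up to the 1024-byte bucket
    drop(a); // strong_count falls back to 1: the buffer becomes reusable
    let b = alloc.allocate(900); // reuses the 1024-byte buffer instead of allocating
    assert_eq!(b.len(), 1024);
    drop(b);
    alloc.drop_unused(); // sweeps buffers with strong_count == 1
    assert!(alloc.buffers.read().unwrap().values().all(|v| v.is_empty()));
}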