diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 00301352..a317d1bf 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -8,6 +8,7 @@ use half::f16; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::sync::{Arc, RwLock}; +use std::collections::HashMap; /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -37,6 +38,7 @@ pub struct MetalDevice { command_queue: metal::CommandQueue, command_buffer: Arc>, kernels: Arc, + buffers: Arc>>>>, } impl std::fmt::Debug for MetalDevice { @@ -87,8 +89,26 @@ impl MetalDevice { &self.device } - pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { - let size = (element_count * dtype.size_in_bytes()) as NSUInteger; + pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc{ + let size = element_count * dtype.size_in_bytes(); + let mut buffers = self.buffers.try_write().unwrap(); + let subbuffers = buffers.entry(size).or_insert(vec![]); + + for sub in &mut *subbuffers{ + // if sub.retain_count() == 1{ + // println!("{size} {:?}", ); + if Arc::strong_count(sub) == 1{ + return sub.clone(); + } + } + let new_buffer = self.device + .new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate); + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + new_buffer + } + + pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer { self.device .new_buffer(size, MTLResourceOptions::StorageModeManaged) } @@ -105,7 +125,7 @@ impl MetalDevice { #[derive(Debug, Clone)] pub struct MetalStorage { - buffer: metal::Buffer, + buffer: Arc, device: MetalDevice, dtype: DType, } @@ -126,29 +146,38 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { + let buffer = self.device.new_buffer_managed(self.buffer.length()); + { + let command = self.device.command_buffer(); + let blit = command.new_blit_command_encoder(); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + + } + self.device.wait_until_completed(); match self.dtype { DType::U8 => Ok(CpuStorage::U8( - self.buffer.read_to_vec(self.buffer.length() as usize), + buffer.read_to_vec(buffer.length() as usize), )), DType::U32 => Ok(CpuStorage::U32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), + buffer.read_to_vec(buffer.length() as usize / 4), )), DType::I64 => Ok(CpuStorage::I64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), + buffer.read_to_vec(buffer.length() as usize / 8), )), DType::F16 => Ok(CpuStorage::F16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), + buffer.read_to_vec(buffer.length() as usize / 2), )), DType::BF16 => Ok(CpuStorage::BF16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), + buffer.read_to_vec(buffer.length() as usize / 2), )), DType::F32 => Ok(CpuStorage::F32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), + buffer.read_to_vec(buffer.length() as usize / 4), )), DType::F64 => Ok(CpuStorage::F64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), + buffer.read_to_vec(buffer.length() as usize / 8), )), } } @@ -175,7 +204,7 @@ impl BackendStorage for MetalStorage { name, el, &self.buffer, - &mut buffer, + &buffer, mul as f32, add as f32, ) @@ -195,7 +224,7 @@ impl BackendStorage for MetalStorage { &self.buffer, layout.stride(), layout.start_offset() * dtype.size_in_bytes(), - &mut buffer, + &buffer, mul as f32, add as f32, ) @@ -270,7 +299,7 @@ impl BackendStorage for MetalStorage { dst_el, &self.buffer, layout.start_offset() * self.dtype.size_in_bytes(), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; @@ -305,7 +334,7 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { @@ -324,7 +353,7 @@ impl BackendStorage for MetalStorage { &self.buffer, layout.stride(), layout.start_offset() * self.dtype.size_in_bytes(), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } @@ -382,7 +411,7 @@ impl BackendStorage for MetalStorage { kernel_name, el_count, &self.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { @@ -425,7 +454,7 @@ impl BackendStorage for MetalStorage { &self.buffer, layout.stride(), layout.start_offset() * self.dtype.size_in_bytes(), - &mut buffer, + &buffer, 0, ) .map_err(MetalError::from)?; @@ -481,7 +510,7 @@ impl BackendStorage for MetalStorage { el_count, &self.buffer, &rhs.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } else { @@ -510,7 +539,7 @@ impl BackendStorage for MetalStorage { &rhs.buffer, rhs_l.stride(), rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; } @@ -551,7 +580,7 @@ impl BackendStorage for MetalStorage { (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), &f.buffer, (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; Ok(Self { @@ -661,7 +690,7 @@ impl BackendStorage for MetalStorage { dim, &self.buffer, &ids.buffer, - &mut buffer, + &buffer, ) .map_err(MetalError::from)?; Ok(Self { @@ -860,7 +889,7 @@ impl BackendStorage for MetalStorage { &self.buffer, src_l.stride(), src_l.start_offset() * self.dtype.size_in_bytes(), - &mut dst.buffer, + &dst.buffer, dst_offset * dst.dtype.size_in_bytes(), ) .map_err(MetalError::from)?; @@ -869,7 +898,7 @@ impl BackendStorage for MetalStorage { } impl MetalStorage { - pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self { + pub fn new(buffer: Arc, device: MetalDevice, dtype: DType) -> Self { Self { buffer, device, @@ -904,10 +933,12 @@ impl BackendDevice for MetalDevice { let command_queue = device.new_command_queue(); let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); let kernels = Arc::new(Kernels::new()); + let buffers = Arc::new(RwLock::new(HashMap::new())); Ok(Self { device, command_queue, command_buffer, + buffers, kernels, }) } @@ -952,7 +983,7 @@ impl BackendDevice for MetalDevice { CpuStorage::F64(storage) => self.new_buffer_with_data(storage), }; Ok(Self::Storage { - buffer, + buffer: buffer.into(), device: self.clone(), dtype: storage.dtype(), }) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e4220286..fcf6930b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -298,7 +298,7 @@ pub fn call_unary_contiguous( kernel_name: unary::contiguous::Kernel, length: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); @@ -320,7 +320,7 @@ pub fn call_unary_strided( input: &Buffer, strides: &[usize], offset: usize, - output: &mut Buffer, + output: &Buffer, output_offset: usize, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; @@ -358,7 +358,7 @@ pub fn call_binary_contiguous( length: usize, left: &Buffer, right: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; @@ -386,7 +386,7 @@ pub fn call_binary_strided( right_input: &Buffer, right_strides: &[usize], right_offset: usize, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; @@ -425,7 +425,7 @@ pub fn call_cast_contiguous( kernel_name: &'static str, length: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; @@ -450,7 +450,7 @@ pub fn call_cast_strided( input: &Buffer, input_strides: &[usize], input_offset: usize, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { // println!("Kernel {:?}", kernel_name.0); // assert_eq!(input.length(), output.length()); @@ -482,7 +482,7 @@ pub fn call_reduce_contiguous( out_length: usize, input: &Buffer, input_offset: usize, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; @@ -523,7 +523,7 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); @@ -564,7 +564,7 @@ pub fn call_affine( name: &'static str, size: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { @@ -590,7 +590,7 @@ pub fn call_affine_strided( input: &Buffer, input_stride: &[usize], input_offset: usize, - output: &mut Buffer, + output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { @@ -632,7 +632,7 @@ pub fn call_where_cond_strided( (left_stride, left_offset): (&[usize], usize), right: &Buffer, (right_stride, right_offset): (&[usize], usize), - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; @@ -675,7 +675,7 @@ pub fn call_index_select( dim: usize, input: &Buffer, ids: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); @@ -750,7 +750,7 @@ mod tests { name, v.len(), &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -775,7 +775,7 @@ mod tests { x.len(), &left, &right, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -805,7 +805,7 @@ mod tests { &input, strides, offset, - &mut output, + &output, 0, ) .unwrap(); @@ -943,7 +943,7 @@ mod tests { name, v.len(), &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -984,7 +984,7 @@ mod tests { "affine_float", size, &input, - &mut output, + &output, mul as f32, add as f32, ) @@ -1021,7 +1021,7 @@ mod tests { &input, strides, 0, - &mut output, + &output, mul as f32, add as f32, ) @@ -1119,7 +1119,7 @@ mod tests { dim, &embeddings_buffer, &ids_buffer, - &mut dst_buffer, + &dst_buffer, ) .unwrap(); @@ -1227,7 +1227,7 @@ mod tests { out_length, &input, 0, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -1255,7 +1255,7 @@ mod tests { v.len(), last_dim, &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -1355,7 +1355,7 @@ mod tests { (&left_stride, left_offset), &right, (&cond_stride, cond_offset), - &mut output, + &output, ) .unwrap(); command_buffer.commit();