From dd4a40f1c0094fc1dc62814038dba7ae36317539 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 13 Nov 2023 14:33:16 +0100 Subject: [PATCH] Fixes + cache compute_pipeline_state. --- candle-core/src/metal_backend.rs | 13 +- candle-metal-kernels/src/lib.rs | 184 +++++++++------------------ candle-metal-kernels/src/unary.metal | 4 +- 3 files changed, 68 insertions(+), 133 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index eca8e1fe..efce19c1 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -86,7 +86,6 @@ impl MetalDevice { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - // debug!("Allocate 1 - buffer size {size}"); self.device .new_buffer(size, MTLResourceOptions::StorageModeManaged) } @@ -115,7 +114,9 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { + let start = std::time::Instant::now(); self.device.wait_until_completed(); + println!("Wait took {:?}", start.elapsed()); match self.dtype { DType::U8 => Ok(CpuStorage::U8( @@ -414,6 +415,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } + self.device.wait_until_completed(); Ok(Self { buffer, device: device.clone(), @@ -899,9 +901,12 @@ impl BackendDevice for MetalDevice { } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - // TODO Is there a faster way ? - let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; - self.storage_from_cpu_storage(&cpu_storage) + let buffer = self.new_buffer(shape.elem_count(), dtype); + Ok(MetalStorage { + buffer, + device: self.clone(), + dtype, + }) } fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index afbcbff7..6b2ab050 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, - ComputePipelineState, Device, Function, Library, MTLSize, + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, + Device, Function, Library, MTLSize, }; use std::collections::HashMap; use std::ffi::c_void; @@ -174,6 +174,10 @@ pub enum MetalKernelError { LoadLibraryError(String), #[error("Error while loading function: {0:?}")] LoadFunctionError(String), + #[error("Failed to create compute function")] + FailedToCreateComputeFunction, + #[error("Failed to create pipeline")] + FailedToCreatePipeline(String), } impl From> for MetalKernelError { @@ -184,19 +188,22 @@ impl From> for MetalKernelError { type KernelMap = HashMap<&'static str, T>; type Libraries = HashMap; -type Functions = KernelMap; +type Pipelines = KernelMap; -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Kernels { libraries: RwLock, - funcs: RwLock, + pipelines: RwLock, } impl Kernels { pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); - let funcs = RwLock::new(Functions::new()); - Self { libraries, funcs } + let pipelines = RwLock::new(Pipelines::new()); + Self { + libraries, + pipelines, + } } // pub fn init(device: &Device) -> Result { @@ -243,22 +250,43 @@ impl Kernels { } } - pub fn load_function( + fn load_function( &self, device: &Device, source: Source, name: &'static str, ) -> Result { - let mut funcs = self.funcs.write()?; - if let Some(func) = funcs.get(name) { - Ok(func.clone()) + let func = self + .load_library(device, source)? + .get_function(name, None) + .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; + Ok(func) + // let mut funcs = self.funcs.write()?; + // if let Some(func) = funcs.get(name) { + // Ok(func.clone()) + // } else { + // funcs.insert(name, func.clone()); + // Ok(func) + // } + } + + pub fn load_pipeline( + &self, + device: &Device, + source: Source, + name: &'static str, + ) -> Result { + let mut pipelines = self.pipelines.write()?; + if let Some(pipeline) = pipelines.get(name) { + Ok(pipeline.clone()) } else { - let func = self - .load_library(device, source)? - .get_function(name, None) - .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; - funcs.insert(name, func.clone()); - Ok(func) + let func = self.load_function(device, source, name)?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert(name, pipeline.clone()); + + Ok(pipeline) } } } @@ -274,16 +302,8 @@ pub fn call_unary_contiguous( ) -> Result<(), MetalKernelError> { // println!("Kernel {:?}", kernel_name.0); // assert_eq!(input.length(), output.length()); - let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -306,15 +326,7 @@ pub fn call_unary_strided( output: &mut Buffer, output_offset: usize, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); @@ -351,17 +363,7 @@ pub fn call_binary_contiguous( right: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); - let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -389,15 +391,7 @@ pub fn call_binary_strided( right_offset: usize, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); @@ -436,17 +430,7 @@ pub fn call_cast_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); - let func = kernels.load_function(device, Source::Cast, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -473,15 +457,7 @@ pub fn call_cast_strided( ) -> Result<(), MetalKernelError> { // println!("Kernel {:?}", kernel_name.0); // assert_eq!(input.length(), output.length()); - let func = kernels.load_function(device, Source::Cast, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -510,16 +486,7 @@ pub fn call_reduce_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); @@ -560,16 +527,7 @@ pub fn call_last_softmax( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -612,15 +570,7 @@ pub fn call_affine( mul: f32, add: f32, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Affine, name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -646,18 +596,9 @@ pub fn call_affine_strided( mul: f32, add: f32, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Affine, name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let size: usize = shape.iter().product(); - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -695,15 +636,7 @@ pub fn call_where_cond_strided( (right_stride, right_offset): (&[usize], usize), output: &mut Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Ternary, name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -751,10 +684,7 @@ pub fn call_index_select( let src_dim_size = shape[dim]; let dst_el = ids_size * left_size * right_size; - let func = kernels.load_function(device, Source::Indexing, name)?; - let pipeline = device - .new_compute_pipeline_state_with_function(&func) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 9f614b30..5389a26b 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -40,7 +40,7 @@ template METAL_FUNC T erf(T in){ float t = 1.0/(1.0 + p*x); float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x); - return (T) sign*y; + return T(sign*y); } template METAL_FUNC T id(T in){ return in; } template METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; } @@ -49,7 +49,7 @@ template METAL_FUNC T gelu(T x){ T x_cube = x_sq * x; T alpha = x + static_cast(0.044715) * x_cube; T beta = (static_cast(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); - return static_cast(0.5) * x * (static_cast(1.0) + tanh(beta)); + return static_cast(0.5) * x * (static_cast(1.0) + T(tanh(beta))); }