diff --git a/candle-metal-kernels/build.rs b/candle-metal-kernels/build.rs
new file mode 100644
index 00000000..f8e46e91
--- /dev/null
+++ b/candle-metal-kernels/build.rs
@@ -0,0 +1,104 @@
+use std::path::PathBuf;
+use std::process::Command;
+use std::{env, str};
+
+const METAL_SOURCES: [&str; 1] = ["reduce"];
+
+/// Apple platform that a metallib is built for.
+enum Platform {
+    MacOS,
+    IOS,
+}
+
+impl Platform {
+    /// SDK name handed to `xcrun --sdk` for this platform.
+    fn sdk(&self) -> &str {
+        match self {
+            Platform::MacOS => "macosx",
+            Platform::IOS => "iphoneos",
+        }
+    }
+}
+
+/// Runs `cmd` to completion and waits for its exit status.
+///
+/// A failure to start the process or a non-zero exit status becomes an `Err`
+/// whose message begins with `stage`.
+fn run_checked(cmd: &mut Command, stage: &str) -> Result<(), String> {
+    let status = cmd.status().map_err(|e| format!("{stage} failed: {e:?}"))?;
+    if status.success() {
+        Ok(())
+    } else {
+        Err(format!("{stage} failed. Exit with status: {status}"))
+    }
+}
+
+/// Builds the Metal library for `platform` inside `OUT_DIR`.
+///
+/// Every file in `METAL_SOURCES` plus `utils.metal` is compiled to `.air`
+/// with `xcrun metal -c`, and the resulting `.air` files are then combined
+/// into `candle.metallib` (macOS) or `candle_ios.metallib` (iOS).
+///
+/// # Errors
+///
+/// Returns a message when the current directory or `OUT_DIR` is unavailable,
+/// or when either `xcrun` invocation cannot be started or exits with a
+/// non-zero status.
+fn compile(platform: Platform) -> Result<(), String> {
+    for metal_file in METAL_SOURCES {
+        println!("cargo::rerun-if-changed=src/{metal_file}.metal");
+    }
+    println!("cargo::rerun-if-changed=src/utils.metal");
+    println!("cargo::rerun-if-changed=build.rs");
+
+    let current_dir =
+        env::current_dir().map_err(|e| format!("Failed to get current directory: {e}"))?;
+    let out_dir = PathBuf::from(env::var("OUT_DIR").map_err(|_| "OUT_DIR not set")?);
+    let working_directory = out_dir.to_string_lossy().to_string();
+    let sources = current_dir.join("src");
+
+    // Compile metal to air
+    // NOTE(review): `-w` presumably silences the warnings that `-Wall -Wextra`
+    // turn on; confirm which of the two is intended.
+    let mut compile_air_cmd = Command::new("xcrun");
+    compile_air_cmd
+        .arg("--sdk")
+        .arg(platform.sdk())
+        .arg("metal")
+        .arg(format!("-working-directory={working_directory}"))
+        .arg("-Wall")
+        .arg("-Wextra")
+        .arg("-O3")
+        .arg("-c")
+        .arg("-w");
+    for metal_file in METAL_SOURCES {
+        compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
+    }
+    compile_air_cmd.arg(sources.join("utils.metal"));
+    // Run the compiler exactly once and wait for it: a second, unwaited
+    // process would write the same `.air` files concurrently.
+    run_checked(&mut compile_air_cmd, "Compiling metal -> air")?;
+
+    // Compile air to metallib
+    let lib_name = match platform {
+        Platform::MacOS => "candle.metallib",
+        Platform::IOS => "candle_ios.metallib",
+    };
+    let metallib = out_dir.join(lib_name);
+    let mut compile_metallib_cmd = Command::new("xcrun");
+    compile_metallib_cmd.arg("metal").arg("-o").arg(&metallib);
+
+    for metal_file in METAL_SOURCES {
+        compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air")));
+    }
+    compile_metallib_cmd.arg(out_dir.join("utils.air"));
+    run_checked(&mut compile_metallib_cmd, "Compiling air -> metallib")
+}
+
+/// Build-script entry point: builds the macOS metallib, then the iOS one.
+fn main() -> Result<(), String> {
+    compile(Platform::MacOS)?;
+    compile(Platform::IOS)?;
+
+    Ok(())
+}
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 6de44f9c..698a25b3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -13,6 +13,11 @@ pub use sort::{call_arg_sort, call_mlx_arg_sort}; pub use utils::BufferOffset; use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider}; +#[cfg(target_os = "macos")] +const CANDLE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/candle.metallib")); +#[cfg(target_os = "ios")] +const CANDLE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/candle_ios.metallib")); + const AFFINE: &str = include_str!("affine.metal"); const BINARY: &str = include_str!("binary.metal"); const CAST: &str = include_str!("cast.metal"); @@ -23,7 +28,6 @@ const MLX_GEMM: &str = include_str!("mlx_gemm.metal"); const MLX_SORT: &str = include_str!("mlx_sort.metal"); const QUANTIZED: &str = include_str!("quantized.metal"); const RANDOM: &str = include_str!("random.metal"); -const REDUCE: &str = include_str!("reduce.metal"); const SORT: &str = include_str!("sort.metal"); const TERNARY: &str = include_str!("ternary.metal"); const UNARY: &str = include_str!("unary.metal"); @@ -54,6 +58,7 @@ impl DType { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { + Candle, Affine, Binary, Cast, @@ -64,7 +69,6 @@ pub enum Source { MlxSort, Quantized, Random, - Reduce, Sort, Ternary, Unary, @@ -288,11 +292,11 @@ impl Kernels { Source::MlxSort => MLX_SORT, Source::Quantized => QUANTIZED, Source::Random => RANDOM, - Source::Reduce => REDUCE, Source::Sort => SORT, Source::Ternary => TERNARY, Source::Unary => UNARY, Source::Sdpa => SDPA, + Source::Candle => panic!("Invalid lib"), } } @@ -307,11 +311,21 @@ impl Kernels { if let Some(lib) = 
libraries.get(&source) { Ok(lib.clone()) } else { - let lib = { - let source_content = self.get_library_source(source); - device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + let lib = match source { + Source::Candle => { + let source_data = CANDLE; + device.new_library_with_data(source_data).map_err(|e| { + MetalKernelError::LoadLibraryError(format!( + "Candle metal requires macosx > 13.0 or higher, cannot load candle metal library: {e}" + )) + })? + } + source => { + let source_content = self.get_library_source(source); + device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + } }; libraries.insert(source, lib.clone()); Ok(lib) @@ -641,7 +655,7 @@ pub fn call_reduce_contiguous( let length = shape.iter().product::(); let num_dims = shape.len(); let work_per_threadgroup = length / out_length; - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); @@ -697,7 +711,7 @@ pub fn call_reduce_strided( let length: usize = shape.iter().product(); let num_dims = shape.len(); let work_per_threadgroup = length / out_length; - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); @@ -752,7 +766,7 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let work_per_threadgroup = elements; - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); 
encoder.set_compute_pipeline_state(&pipeline); @@ -801,7 +815,7 @@ pub fn call_rms_norm( alpha_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -862,7 +876,7 @@ pub fn call_layer_norm( beta_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -923,7 +937,7 @@ pub fn call_rope_i( sin_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -966,7 +980,7 @@ pub fn call_rope_thd( sin_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1010,7 +1024,7 @@ pub fn call_rope( sin_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?; let encoder = ep.encoder(); let encoder: 
&ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 291c81e6..d976d101 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -1,51 +1,8 @@ #include #include +#include "utils.metal" using namespace metal; -METAL_FUNC uint nonzero(uint n) { - return n == 0 ? 1 : n; -} - -template -constexpr uint nonzero() { - return N == 0 ? 1 : N; -} - -template -constexpr ushort granularity() { - return nonzero::value>(); -} - -METAL_FUNC uint next_p2(uint x) { - return 1 << (32 - clz(x - 1)); -} - -METAL_FUNC uint prev_p2(uint x) { - return 1 << (31 - clz(x)); -} - -constant uint MAX_SHARED_MEM = 32767; - -template -METAL_FUNC uint max_shared_mem(uint n) { - return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T))); -} - -METAL_FUNC uint get_strided_index( - uint idx, - constant const uint &num_dims, - constant const size_t *dims, - constant const size_t *strides -) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; - } - return strided_i; -} - struct Divide { template METAL_FUNC T operator()(T a, T b) { return a / b; }