mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add metal precompilation via build.rs
This commit is contained in:
123
candle-metal-kernels/build.rs
Normal file
123
candle-metal-kernels/build.rs
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
use std::path::PathBuf;
|
||||||
|
use std::process::Command;
|
||||||
|
use std::{env, str};
|
||||||
|
|
||||||
|
// Base names (without the `.metal` extension) of the kernel sources under `src/`
// that `compile` turns into `.air` files and links into the metallib.
// `utils.metal` is not listed here; `compile` adds it explicitly.
const METAL_SOURCES: [&str; 1] = ["reduce"];
|
||||||
|
|
||||||
|
/// Apple platform the Metal kernels are precompiled for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Platform {
    MacOS,
    IOS,
}

impl Platform {
    /// Returns the SDK name passed to `xcrun --sdk` for this platform.
    ///
    /// The names are string literals, so the result is `'static` and does not
    /// borrow from `self`.
    fn sdk(&self) -> &'static str {
        match self {
            Platform::MacOS => "macosx",
            Platform::IOS => "iphoneos",
        }
    }
}
|
||||||
|
|
||||||
|
fn compile(platform: Platform) -> Result<(), String> {
|
||||||
|
println!("cargo::rerun-if-changed=src/reduce.metal");
|
||||||
|
println!("cargo::rerun-if-changed=src/utils.metal");
|
||||||
|
println!("cargo::rerun-if-changed=build.rs");
|
||||||
|
|
||||||
|
let current_dir = env::current_dir().expect("Failed to get current directory");
|
||||||
|
let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_| "OUT_DIR not set")?);
|
||||||
|
let working_directory = out_dir.to_string_lossy().to_string();
|
||||||
|
let sources = current_dir.join("src");
|
||||||
|
|
||||||
|
// Compile metal to air
|
||||||
|
let mut compile_air_cmd = Command::new("xcrun");
|
||||||
|
compile_air_cmd
|
||||||
|
.arg("--sdk")
|
||||||
|
.arg(platform.sdk())
|
||||||
|
.arg("metal")
|
||||||
|
.arg(format!("-working-directory={working_directory}"))
|
||||||
|
.arg("-Wall")
|
||||||
|
.arg("-Wextra")
|
||||||
|
.arg("-O3")
|
||||||
|
.arg("-c")
|
||||||
|
.arg("-w");
|
||||||
|
for metal_file in METAL_SOURCES {
|
||||||
|
compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
|
||||||
|
}
|
||||||
|
compile_air_cmd.arg(sources.join("utils.metal"));
|
||||||
|
compile_air_cmd.spawn().expect("Failed to compile air");
|
||||||
|
|
||||||
|
let mut child = compile_air_cmd.spawn().expect("Failed to compile air");
|
||||||
|
|
||||||
|
match child.try_wait() {
|
||||||
|
Ok(Some(status)) => {
|
||||||
|
if !status.success() {
|
||||||
|
panic!(
|
||||||
|
"Compiling metal -> air failed. Exit with status: {}",
|
||||||
|
status
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
let status = child
|
||||||
|
.wait()
|
||||||
|
.expect("Compiling metal -> air failed while waiting for result");
|
||||||
|
if !status.success() {
|
||||||
|
panic!(
|
||||||
|
"Compiling metal -> air failed. Exit with status: {}",
|
||||||
|
status
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => panic!("Compiling metal -> air failed: {:?}", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile air to metallib
|
||||||
|
let lib_name = match platform {
|
||||||
|
Platform::MacOS => "candle.metallib",
|
||||||
|
Platform::IOS => "candle_ios.metallib",
|
||||||
|
};
|
||||||
|
let metallib = out_dir.join(lib_name);
|
||||||
|
let mut compile_metallib_cmd = Command::new("xcrun");
|
||||||
|
compile_metallib_cmd.arg("metal").arg("-o").arg(&metallib);
|
||||||
|
|
||||||
|
for metal_file in METAL_SOURCES {
|
||||||
|
compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air")));
|
||||||
|
}
|
||||||
|
compile_metallib_cmd.arg(out_dir.join("utils.air"));
|
||||||
|
|
||||||
|
let mut child = compile_metallib_cmd
|
||||||
|
.spawn()
|
||||||
|
.expect("Failed to compile air -> metallib");
|
||||||
|
|
||||||
|
match child.try_wait() {
|
||||||
|
Ok(Some(status)) => {
|
||||||
|
if !status.success() {
|
||||||
|
panic!(
|
||||||
|
"Compiling air -> metallib failed. Exit with status: {}",
|
||||||
|
status
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
let status = child
|
||||||
|
.wait()
|
||||||
|
.expect("Compiling air -> metallib failed while waiting for result");
|
||||||
|
if !status.success() {
|
||||||
|
panic!(
|
||||||
|
"Compiling air -> metallib failed. Exit with status: {}",
|
||||||
|
status
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => panic!("Compiling air -> metallib failed: {:?}", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<(), String> {
|
||||||
|
compile(Platform::MacOS)?;
|
||||||
|
compile(Platform::IOS)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -13,6 +13,11 @@ pub use sort::{call_arg_sort, call_mlx_arg_sort};
|
|||||||
pub use utils::BufferOffset;
|
pub use utils::BufferOffset;
|
||||||
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
|
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
const CANDLE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/candle.metallib"));
|
||||||
|
#[cfg(target_os = "ios")]
|
||||||
|
const CANDLE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/candle_ios.metallib"));
|
||||||
|
|
||||||
const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &str = include_str!("affine.metal");
|
||||||
const BINARY: &str = include_str!("binary.metal");
|
const BINARY: &str = include_str!("binary.metal");
|
||||||
const CAST: &str = include_str!("cast.metal");
|
const CAST: &str = include_str!("cast.metal");
|
||||||
@ -23,7 +28,6 @@ const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
|||||||
const MLX_SORT: &str = include_str!("mlx_sort.metal");
|
const MLX_SORT: &str = include_str!("mlx_sort.metal");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &str = include_str!("random.metal");
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
|
||||||
const SORT: &str = include_str!("sort.metal");
|
const SORT: &str = include_str!("sort.metal");
|
||||||
const TERNARY: &str = include_str!("ternary.metal");
|
const TERNARY: &str = include_str!("ternary.metal");
|
||||||
const UNARY: &str = include_str!("unary.metal");
|
const UNARY: &str = include_str!("unary.metal");
|
||||||
@ -54,6 +58,7 @@ impl DType {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum Source {
|
pub enum Source {
|
||||||
|
Candle,
|
||||||
Affine,
|
Affine,
|
||||||
Binary,
|
Binary,
|
||||||
Cast,
|
Cast,
|
||||||
@ -64,7 +69,6 @@ pub enum Source {
|
|||||||
MlxSort,
|
MlxSort,
|
||||||
Quantized,
|
Quantized,
|
||||||
Random,
|
Random,
|
||||||
Reduce,
|
|
||||||
Sort,
|
Sort,
|
||||||
Ternary,
|
Ternary,
|
||||||
Unary,
|
Unary,
|
||||||
@ -288,11 +292,11 @@ impl Kernels {
|
|||||||
Source::MlxSort => MLX_SORT,
|
Source::MlxSort => MLX_SORT,
|
||||||
Source::Quantized => QUANTIZED,
|
Source::Quantized => QUANTIZED,
|
||||||
Source::Random => RANDOM,
|
Source::Random => RANDOM,
|
||||||
Source::Reduce => REDUCE,
|
|
||||||
Source::Sort => SORT,
|
Source::Sort => SORT,
|
||||||
Source::Ternary => TERNARY,
|
Source::Ternary => TERNARY,
|
||||||
Source::Unary => UNARY,
|
Source::Unary => UNARY,
|
||||||
Source::Sdpa => SDPA,
|
Source::Sdpa => SDPA,
|
||||||
|
Source::Candle => panic!("Invalid lib"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -307,11 +311,21 @@ impl Kernels {
|
|||||||
if let Some(lib) = libraries.get(&source) {
|
if let Some(lib) = libraries.get(&source) {
|
||||||
Ok(lib.clone())
|
Ok(lib.clone())
|
||||||
} else {
|
} else {
|
||||||
let lib = {
|
let lib = match source {
|
||||||
let source_content = self.get_library_source(source);
|
Source::Candle => {
|
||||||
device
|
let source_data = CANDLE;
|
||||||
.new_library_with_source(source_content, &CompileOptions::new())
|
device.new_library_with_data(source_data).map_err(|e| {
|
||||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
MetalKernelError::LoadLibraryError(format!(
|
||||||
|
"Candle metal requires macosx > 13.0 or higher, cannot load candle metal library: {e}"
|
||||||
|
))
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
source => {
|
||||||
|
let source_content = self.get_library_source(source);
|
||||||
|
device
|
||||||
|
.new_library_with_source(source_content, &CompileOptions::new())
|
||||||
|
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
libraries.insert(source, lib.clone());
|
libraries.insert(source, lib.clone());
|
||||||
Ok(lib)
|
Ok(lib)
|
||||||
@ -641,7 +655,7 @@ pub fn call_reduce_contiguous(
|
|||||||
let length = shape.iter().product::<usize>();
|
let length = shape.iter().product::<usize>();
|
||||||
let num_dims = shape.len();
|
let num_dims = shape.len();
|
||||||
let work_per_threadgroup = length / out_length;
|
let work_per_threadgroup = length / out_length;
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
|
||||||
|
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
@ -697,7 +711,7 @@ pub fn call_reduce_strided(
|
|||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
let num_dims = shape.len();
|
let num_dims = shape.len();
|
||||||
let work_per_threadgroup = length / out_length;
|
let work_per_threadgroup = length / out_length;
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
|
||||||
|
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
@ -752,7 +766,7 @@ pub fn call_last_softmax(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let work_per_threadgroup = elements;
|
let work_per_threadgroup = elements;
|
||||||
|
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -801,7 +815,7 @@ pub fn call_rms_norm(
|
|||||||
alpha_offset: usize,
|
alpha_offset: usize,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -862,7 +876,7 @@ pub fn call_layer_norm(
|
|||||||
beta_offset: usize,
|
beta_offset: usize,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -923,7 +937,7 @@ pub fn call_rope_i(
|
|||||||
sin_offset: usize,
|
sin_offset: usize,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -966,7 +980,7 @@ pub fn call_rope_thd(
|
|||||||
sin_offset: usize,
|
sin_offset: usize,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -1010,7 +1024,7 @@ pub fn call_rope(
|
|||||||
sin_offset: usize,
|
sin_offset: usize,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
@ -1,51 +1,8 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
#include <metal_limits>
|
#include <metal_limits>
|
||||||
|
#include "utils.metal"
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
METAL_FUNC uint nonzero(uint n) {
|
|
||||||
return n == 0 ? 1 : n;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<uint N>
|
|
||||||
constexpr uint nonzero() {
|
|
||||||
return N == 0 ? 1 : N;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
constexpr ushort granularity() {
|
|
||||||
return nonzero<vec_elements<T>::value>();
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC uint next_p2(uint x) {
|
|
||||||
return 1 << (32 - clz(x - 1));
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC uint prev_p2(uint x) {
|
|
||||||
return 1 << (31 - clz(x));
|
|
||||||
}
|
|
||||||
|
|
||||||
constant uint MAX_SHARED_MEM = 32767;
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
METAL_FUNC uint max_shared_mem(uint n) {
|
|
||||||
return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
|
||||||
uint idx,
|
|
||||||
constant const uint &num_dims,
|
|
||||||
constant const size_t *dims,
|
|
||||||
constant const size_t *strides
|
|
||||||
) {
|
|
||||||
uint strided_i = 0;
|
|
||||||
for (uint d = 0; d < num_dims; d++) {
|
|
||||||
uint dim_idx = num_dims - 1 - d;
|
|
||||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
|
||||||
idx /= dims[dim_idx];
|
|
||||||
}
|
|
||||||
return strided_i;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Divide {
|
struct Divide {
|
||||||
template<typename T>
|
template<typename T>
|
||||||
METAL_FUNC T operator()(T a, T b) { return a / b; }
|
METAL_FUNC T operator()(T a, T b) { return a / b; }
|
||||||
|
Reference in New Issue
Block a user