mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Initial generic metallib build.rs script
This commit is contained in:
1
candle-metal-kernels/.gitignore
vendored
Normal file
1
candle-metal-kernels/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
src/air
|
137
candle-metal-kernels/build.rs
Normal file
137
candle-metal-kernels/build.rs
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
use std::process::Command;
|
||||||
|
use std::{env, str};
|
||||||
|
|
||||||
|
/// Base names (without the `.metal` extension) of the kernel sources under `src/kernels`
/// that the build script compiles to `.air` and links into the metallib.
const COMPILED_KERNELS: [&str; 1] = ["reduce"];
|
||||||
|
|
||||||
|
/// Apple platform whose SDK can be looked up through `xcrun --sdk`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Platform {
    /// Selects the `macosx` SDK.
    MacOS,
    /// Selects the `iphoneos` SDK.
    IOS,
}

impl Platform {
    /// Returns the SDK identifier that is handed to `xcrun --sdk`.
    fn as_str(&self) -> &'static str {
        match self {
            Platform::MacOS => "macosx",
            Platform::IOS => "iphoneos",
        }
    }
}
|
||||||
|
|
||||||
|
fn get_xcode_sdk_path(platform: Platform) -> Result<String, String> {
|
||||||
|
let xcrun_output = Command::new("xcrun")
|
||||||
|
.args(["--sdk", platform.as_str(), "--show-sdk-path"])
|
||||||
|
.output()
|
||||||
|
.expect("xcrun command failed to start");
|
||||||
|
|
||||||
|
Ok(str::from_utf8(&xcrun_output.stdout)
|
||||||
|
.expect("Invalid UTF-8 from xcrun")
|
||||||
|
.replace('\n', ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compile_candle_metallib(sdk_path: String, bfloat_support: bool) -> Result<(), String> {
|
||||||
|
let current_dir = env::current_dir().expect("Failed to get current directory");
|
||||||
|
let out_dir = current_dir.join("src/libraries");
|
||||||
|
let air_dir = current_dir.join("src/air");
|
||||||
|
let working_directory = air_dir.display();
|
||||||
|
let sources = current_dir.join("src/kernels");
|
||||||
|
|
||||||
|
// Compile metal to air
|
||||||
|
let mut compile_air_cmd = Command::new("xcrun");
|
||||||
|
compile_air_cmd
|
||||||
|
.arg("metal")
|
||||||
|
.arg(format!("-working-directory={working_directory}"))
|
||||||
|
.arg("-Wall")
|
||||||
|
.arg("-Wextra")
|
||||||
|
.arg("-O3")
|
||||||
|
.arg("-c")
|
||||||
|
.arg("-w");
|
||||||
|
for metal_file in COMPILED_KERNELS {
|
||||||
|
compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
|
||||||
|
}
|
||||||
|
compile_air_cmd.arg(sources.join("utils.metal"));
|
||||||
|
compile_air_cmd.spawn().expect("Failed to compile air");
|
||||||
|
|
||||||
|
let mut child = compile_air_cmd.spawn().expect("Failed to compile air");
|
||||||
|
|
||||||
|
match child.try_wait() {
|
||||||
|
Ok(Some(status)) => {
|
||||||
|
if !status.success() {
|
||||||
|
panic!(
|
||||||
|
"Compiling metal -> air failed. Exit with status: {}",
|
||||||
|
status
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
let status = child
|
||||||
|
.wait()
|
||||||
|
.expect("Compiling metal -> air failed while waiting for result");
|
||||||
|
if !status.success() {
|
||||||
|
panic!(
|
||||||
|
"Compiling metal -> air failed. Exit with status: {}",
|
||||||
|
status
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => panic!("Compiling metal -> air failed: {:?}", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile air to metallib
|
||||||
|
let metallib = out_dir.join("candle.metallib");
|
||||||
|
|
||||||
|
let mut compile_metallib_cmd = Command::new("xcrun");
|
||||||
|
compile_metallib_cmd.arg("metal").arg("-o").arg(&metallib);
|
||||||
|
|
||||||
|
for metal_file in COMPILED_KERNELS {
|
||||||
|
compile_metallib_cmd.arg(air_dir.join(format!("{metal_file}.air")));
|
||||||
|
}
|
||||||
|
compile_metallib_cmd.arg(air_dir.join("utils.air"));
|
||||||
|
|
||||||
|
let mut child = compile_metallib_cmd
|
||||||
|
.spawn()
|
||||||
|
.expect("Failed to compile air -> metallib");
|
||||||
|
|
||||||
|
match child.try_wait() {
|
||||||
|
Ok(Some(status)) => {
|
||||||
|
if !status.success() {
|
||||||
|
panic!(
|
||||||
|
"Compiling air -> metallib failed. Exit with status: {}",
|
||||||
|
status
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
let status = child
|
||||||
|
.wait()
|
||||||
|
.expect("Compiling air -> metallib failed while waiting for result");
|
||||||
|
if !status.success() {
|
||||||
|
panic!(
|
||||||
|
"Compiling air -> metallib failed. Exit with status: {}",
|
||||||
|
status
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => panic!("Compiling air -> metallib failed: {:?}", e),
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<(), String> {
|
||||||
|
println!("cargo::rerun-if-changed=build.rs");
|
||||||
|
|
||||||
|
let current_dir = env::current_dir().expect("Failed to get current directory");
|
||||||
|
let sources = current_dir.join("src/kernels");
|
||||||
|
|
||||||
|
for metal_file in COMPILED_KERNELS {
|
||||||
|
println!("cargo::rerun-if-changed={}", sources.join(format!("{metal_file}.metal")).display());
|
||||||
|
println!("cargo:warning=output {}", sources.join(format!("{metal_file}.metal")).display());
|
||||||
|
}
|
||||||
|
|
||||||
|
let macos_sdk = get_xcode_sdk_path(Platform::MacOS).expect("Failed to get MacOS SDK path");
|
||||||
|
let iphoneos_sdk = get_xcode_sdk_path(Platform::IOS).expect("Failed to get IOS SDK path");
|
||||||
|
|
||||||
|
compile_candle_metallib(macos_sdk, false)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -104,7 +104,7 @@ METAL_FUNC void argmax(
|
|||||||
threadgroup T * shared_memory,
|
threadgroup T * shared_memory,
|
||||||
threadgroup uint * shared_indices
|
threadgroup uint * shared_indices
|
||||||
) {
|
) {
|
||||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||||
// to (dst_id + 1) * el_to_sum_per_block.
|
// to (dst_id + 1) * el_to_sum_per_block.
|
||||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||||
@ -173,7 +173,7 @@ METAL_FUNC void reduce(
|
|||||||
threadgroup T * shared_memory,
|
threadgroup T * shared_memory,
|
||||||
T (*fn)(T, T)
|
T (*fn)(T, T)
|
||||||
) {
|
) {
|
||||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||||
// to (dst_id + 1) * el_to_sum_per_block.
|
// to (dst_id + 1) * el_to_sum_per_block.
|
||||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
size_t stop_idx = start_idx + el_to_sum_per_block;
|
47
candle-metal-kernels/src/kernels/utils.metal
Normal file
47
candle-metal-kernels/src/kernels/utils.metal
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <metal_stdlib>
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
// Maps 0 to 1 and returns every other value unchanged.
METAL_FUNC uint nonzero(uint n) {
    if (n == 0) {
        return 1;
    }
    return n;
}
|
||||||
|
|
||||||
|
// Compile-time counterpart of nonzero(uint): yields 1 when N is 0, otherwise N.
template<uint N>
constexpr uint nonzero() {
    return N != 0 ? N : 1;
}
|
||||||
|
|
||||||
|
// Returns vec_elements<T>::value with 0 replaced by 1.
// NOTE(review): presumably vec_elements reports the lane count of a vector type and
// 0 for scalars, making this "elements per T" - confirm against metal_stdlib.
template<typename T>
constexpr ushort granularity() {
    constexpr uint lanes = vec_elements<T>::value;
    return nonzero<lanes>();
}
|
||||||
|
|
||||||
|
// Returns the smallest power of two that is >= x.
//
// Edge cases:
//  - x <= 1 yields 1 (the unguarded formula would shift by 32 for x == 0);
//  - x > 2^31 has no representable answer in 32 bits and is clamped to 2^31.
METAL_FUNC uint next_p2(uint x) {
    if (x <= 1u) {
        return 1u;
    }
    if (x > 0x80000000u) {
        return 0x80000000u;
    }
    // Unsigned literal: a signed `1 << 31` would overflow int.
    return 1u << (32 - clz(x - 1));
}
|
||||||
|
|
||||||
|
// Returns the largest power of two that is <= x, or 0 when x is 0.
METAL_FUNC uint prev_p2(uint x) {
    // No power of two is <= 0; without this guard the shift count would underflow.
    if (x == 0u) {
        return 0u;
    }
    // Unsigned literal: a signed `1 << 31` would overflow int for x >= 2^31.
    return 1u << (31 - clz(x));
}
|
||||||
|
|
||||||
|
// Byte budget that max_shared_mem<T>() divides by sizeof(T) to cap element counts.
// NOTE(review): 32767 is one short of 32 KiB, so after the power-of-two rounding the
// element cap is half of what 32768 would give - confirm this value is intended.
constant uint MAX_SHARED_MEM = 32767;
|
||||||
|
|
||||||
|
// Clamps an element count n to the largest power-of-two number of T values that
// fits in MAX_SHARED_MEM bytes.
template<typename T>
METAL_FUNC uint max_shared_mem(uint n) {
    const uint capacity = prev_p2(MAX_SHARED_MEM / sizeof(T));
    return min(n, capacity);
}
|
||||||
|
|
||||||
|
// Converts a linear element index into a strided offset.
//
// Walks the dimensions from the last to the first: at each one the remainder of
// idx by that dimension's extent is multiplied by the matching stride and added to
// the offset, then idx is divided by the extent.
//
//   idx      - linear index to convert
//   num_dims - number of entries read from dims and strides
//   dims     - extent of each dimension
//   strides  - stride of each dimension
//
// The accumulated offset is returned as a uint.
METAL_FUNC uint get_strided_index(
    uint idx,
    constant const uint &num_dims,
    constant const size_t *dims,
    constant const size_t *strides
) {
    uint offset = 0;
    for (uint remaining = num_dims; remaining > 0; remaining--) {
        const uint axis = remaining - 1;
        const size_t extent = dims[axis];
        offset += (idx % extent) * strides[axis];
        idx /= extent;
    }
    return offset;
}
|
@ -10,21 +10,23 @@ mod utils;
|
|||||||
pub use utils::BufferOffset;
|
pub use utils::BufferOffset;
|
||||||
use utils::{get_block_dims, linear_split};
|
use utils::{get_block_dims, linear_split};
|
||||||
|
|
||||||
const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &str = include_str!("kernels/affine.metal");
|
||||||
const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &str = include_str!("kernels/indexing.metal");
|
||||||
const UNARY: &str = include_str!("unary.metal");
|
const UNARY: &str = include_str!("kernels/unary.metal");
|
||||||
const BINARY: &str = include_str!("binary.metal");
|
const BINARY: &str = include_str!("kernels/binary.metal");
|
||||||
const TERNARY: &str = include_str!("ternary.metal");
|
const TERNARY: &str = include_str!("kernels/ternary.metal");
|
||||||
const CAST: &str = include_str!("cast.metal");
|
const CAST: &str = include_str!("kernels/cast.metal");
|
||||||
const CONV: &str = include_str!("conv.metal");
|
const CONV: &str = include_str!("kernels/conv.metal");
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &str = include_str!("kernels/reduce.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &str = include_str!("kernels/random.metal");
|
||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
const QUANTIZED: &str = include_str!("kernels/quantized.metal");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const SORT: &str = include_str!("kernels/sort.metal");
|
||||||
const SORT: &str = include_str!("sort.metal");
|
const MFA: &[u8] = include_bytes!("libraries/libMetalFlashAttention.metallib");
|
||||||
|
const CANDLE: &[u8] = include_bytes!("libraries/libMetalFlashAttention.metallib");
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum Source {
|
pub enum Source {
|
||||||
|
Candle,
|
||||||
Affine,
|
Affine,
|
||||||
Indexing,
|
Indexing,
|
||||||
Unary,
|
Unary,
|
||||||
@ -200,7 +202,7 @@ impl Kernels {
|
|||||||
Source::Random => RANDOM,
|
Source::Random => RANDOM,
|
||||||
Source::Quantized => QUANTIZED,
|
Source::Quantized => QUANTIZED,
|
||||||
Source::Sort => SORT,
|
Source::Sort => SORT,
|
||||||
Source::Mfa => panic!("Invalid lib"),
|
_ => panic!("Invalid lib"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -216,9 +218,15 @@ impl Kernels {
|
|||||||
Ok(lib.clone())
|
Ok(lib.clone())
|
||||||
} else {
|
} else {
|
||||||
let lib = match source {
|
let lib = match source {
|
||||||
|
Source::Candle => {
|
||||||
|
device.new_library_with_data(CANDLE).map_err(|e| {
|
||||||
|
MetalKernelError::LoadLibraryError(format!(
|
||||||
|
"Candle metal requires macosx > 13.0 or higher, cannot load candle: {e}"
|
||||||
|
))
|
||||||
|
})?
|
||||||
|
}
|
||||||
Source::Mfa => {
|
Source::Mfa => {
|
||||||
let source_data = MFA;
|
device.new_library_with_data(MFA).map_err(|e| {
|
||||||
device.new_library_with_data(source_data).map_err(|e| {
|
|
||||||
MetalKernelError::LoadLibraryError(format!(
|
MetalKernelError::LoadLibraryError(format!(
|
||||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||||
))
|
))
|
||||||
|
BIN
candle-metal-kernels/src/libraries/candle.metallib
Normal file
BIN
candle-metal-kernels/src/libraries/candle.metallib
Normal file
Binary file not shown.
Reference in New Issue
Block a user