Compare commits

...

1 Commit

Author SHA1 Message Date
5ac3302fac Prebuild all our kernels. 2024-03-18 16:39:38 +01:00
5 changed files with 70 additions and 31 deletions

View File

@ -588,6 +588,7 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::F16) => "cast_u32_f16",
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U8, DType::U32) => "cast_u8_u32",

2
candle-metal-kernels/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
src/compiled/

View File

@ -0,0 +1,45 @@
use std::path::Path;
use std::process::Command;
/// Build-script entry point: compiles every `*.metal` file found directly in
/// `src/` by handing it to `build_kernel`.
///
/// # Errors
///
/// Returns an error if the output directory cannot be created, if `src/`
/// cannot be listed, if a directory entry cannot be read, or if
/// `build_kernel` fails for any shader.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `src/compiled/` is git-ignored, so it may not exist on a fresh checkout;
    // make sure it is there before anything is written into it.
    std::fs::create_dir_all("src/compiled")?;

    // Propagate a listing failure instead of panicking: this function already
    // returns a `Result`, and cargo reports the error message.
    for file in std::fs::read_dir("src/")? {
        let path = file?.path();
        if let Some(extension) = path.extension() {
            if extension == "metal" {
                build_kernel(&path)?;
            }
            // Diagnostic emitted for every entry that has an extension,
            // whether or not it was compiled.
            println!("cargo:warning=output {:?}", path.file_stem());
        }
    }
    Ok(())
}
/// Compiles one Metal shader source into `src/compiled/<stem>.metallib`.
///
/// Two tools are run in sequence through `xcrun`:
/// 1. `metal -c <path> -I src/ -o src/compiled/<stem>.air`
/// 2. `metallib src/compiled/<stem>.air -o src/compiled/<stem>.metallib`
///
/// # Errors
///
/// Returns an error if `path` has no UTF-8 file stem, if `xcrun` cannot be
/// spawned, or if either tool exits with a non-zero status (the tool's stderr
/// is included in the message).
fn build_kernel(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
    let stem = path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .ok_or_else(|| format!("kernel path {path:?} has no UTF-8 file stem"))?;
    let air = format!("src/compiled/{stem}.air");
    let metallib = format!("src/compiled/{stem}.metallib");

    // Pass the path as an OS string so a non-UTF-8 directory component does
    // not cause a panic.
    run_checked(
        Command::new("xcrun")
            .args(["metal", "-c"])
            .arg(path)
            .args(["-I", "src/", "-o", &air]),
        "xcrun metal",
    )?;
    run_checked(
        Command::new("xcrun").args(["metallib", &air, "-o", &metallib]),
        "xcrun metallib",
    )
}

/// Runs `command` to completion and turns a non-zero exit status into an error.
///
/// `Command::output` only fails when the process cannot be spawned; a tool
/// that starts and then reports a compile error still yields `Ok`, so the
/// exit status has to be inspected explicitly.
fn run_checked(command: &mut Command, step: &str) -> Result<(), Box<dyn std::error::Error>> {
    let output = command.output()?;
    if output.status.success() {
        return Ok(());
    }
    Err(format!(
        "{step} failed ({}): {}",
        output.status,
        String::from_utf8_lossy(&output.stderr).trim()
    )
    .into())
}

View File

@ -73,6 +73,7 @@ kernel void FN_NAME_STRIDED( \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
@ -95,4 +96,4 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
#endif

View File

@ -1,22 +1,22 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib");
const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib");
const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");
const BINARY: &[u8] = include_bytes!("compiled/binary.metallib");
const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib");
const CAST: &[u8] = include_bytes!("compiled/cast.metallib");
const CONV: &[u8] = include_bytes!("compiled/conv.metallib");
const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib");
const RANDOM: &[u8] = include_bytes!("compiled/random.metallib");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib");
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@ -235,7 +235,7 @@ impl Kernels {
}
}
fn get_library_source(&self, source: Source) -> &'static str {
fn get_library_source(&self, source: Source) -> &'static [u8] {
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
@ -247,7 +247,7 @@ impl Kernels {
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"),
Source::Mfa => MFA,
}
}
@ -262,22 +262,12 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let lib = match source {
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
}
source => {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
}
};
let source_data = self.get_library_source(source);
let lib = device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?;
libraries.insert(source, lib.clone());
Ok(lib)
}