diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index c19d7c56..020e6679 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -588,6 +588,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U32, DType::I64) => "cast_u32_i64", + (DType::U32, DType::F16) => "cast_u32_f16", (DType::U32, DType::BF16) => "cast_u32_bf16", (DType::U8, DType::U32) => "cast_u8_u32", diff --git a/candle-metal-kernels/.gitignore b/candle-metal-kernels/.gitignore new file mode 100644 index 00000000..0de0261d --- /dev/null +++ b/candle-metal-kernels/.gitignore @@ -0,0 +1,2 @@ +src/compiled/ + diff --git a/candle-metal-kernels/build.rs b/candle-metal-kernels/build.rs new file mode 100644 index 00000000..157d33cd --- /dev/null +++ b/candle-metal-kernels/build.rs @@ -0,0 +1,45 @@ +use std::path::Path; +use std::process::Command; + +fn main() -> Result<(), Box> { + let files: std::fs::ReadDir = std::fs::read_dir("src/").unwrap(); + for file in files { + let file = file?; + let path = file.path(); + if let Some(extension) = path.extension() { + if extension == "metal" { + build_kernel(&path)?; + } + println!("cargo:warning=output {:?}", path.file_stem()); + } + } + Ok(()) +} + +fn build_kernel(path: &Path) -> Result<(), Box> { + let stem = path + .file_stem() + .expect("expect real filename") + .to_str() + .expect("expect real stem"); + Command::new("xcrun") + .args([ + "metal", + "-c", + path.as_os_str().to_str().expect("Expect a real filename"), + "-I", + "src/", + "-o", + &format!("src/compiled/{stem}.air"), + ]) + .output()?; + Command::new("xcrun") + .args([ + "metallib", + &format!("src/compiled/{stem}.air"), + "-o", + &format!("src/compiled/{stem}.metallib"), + ]) + .output()?; + Ok(()) +} diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 9aead139..49c11218 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -73,6 +73,7 @@ kernel void FN_NAME_STRIDED( \ } \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) +CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) @@ -95,4 +96,4 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) -#endif \ No newline at end of file +#endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 33bc3453..ec32a798 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,22 +1,22 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, - Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, + Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, + FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; -const AFFINE: &str = include_str!("affine.metal"); -const INDEXING: &str = include_str!("indexing.metal"); -const UNARY: &str = include_str!("unary.metal"); -const BINARY: &str = include_str!("binary.metal"); -const TERNARY: &str = include_str!("ternary.metal"); -const CAST: &str = include_str!("cast.metal"); -const CONV: &str = include_str!("conv.metal"); -const REDUCE: &str = include_str!("reduce.metal"); -const RANDOM: &str = include_str!("random.metal"); +const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib"); +const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib"); +const UNARY: &[u8] = include_bytes!("compiled/unary.metallib"); +const BINARY: &[u8] = include_bytes!("compiled/binary.metallib"); +const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib"); +const CAST: &[u8] = include_bytes!("compiled/cast.metallib"); +const CONV: &[u8] = include_bytes!("compiled/conv.metallib"); +const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib"); +const RANDOM: &[u8] = include_bytes!("compiled/random.metallib"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); -const QUANTIZED: &str = include_str!("quantized.metal"); +const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib"); /// Most kernels apply similarly across the tensors /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the @@ -235,7 +235,7 @@ impl Kernels { } } - fn get_library_source(&self, source: Source) -> &'static str { + fn get_library_source(&self, source: Source) -> &'static [u8] { match source { Source::Affine => AFFINE, Source::Unary => UNARY, @@ -247,7 +247,7 @@ impl Kernels { Source::Conv => CONV, Source::Random => RANDOM, Source::Quantized => QUANTIZED, - Source::Mfa => panic!("Invalid lib"), + Source::Mfa => MFA, } } @@ -262,22 +262,12 @@ impl Kernels { if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { - let lib = match source { - Source::Mfa => { - let source_data = MFA; - device.new_library_with_data(source_data).map_err(|e| { - MetalKernelError::LoadLibraryError(format!( - "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}" - )) - })? - } - source => { - let source_content = self.get_library_source(source); - device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? - } - }; + let source_data = self.get_library_source(source); + let lib = device.new_library_with_data(source_data).map_err(|e| { + MetalKernelError::LoadLibraryError(format!( + "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}" + )) + })?; libraries.insert(source, lib.clone()); Ok(lib) }