mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Compare commits
1 Commits
0.6.0
...
precompile
Author | SHA1 | Date | |
---|---|---|---|
5ac3302fac |
@ -588,6 +588,7 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::F16) => "cast_u32_f16",
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
|
2
candle-metal-kernels/.gitignore
vendored
Normal file
2
candle-metal-kernels/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
src/compiled/
|
||||
|
45
candle-metal-kernels/build.rs
Normal file
45
candle-metal-kernels/build.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let files: std::fs::ReadDir = std::fs::read_dir("src/").unwrap();
|
||||
for file in files {
|
||||
let file = file?;
|
||||
let path = file.path();
|
||||
if let Some(extension) = path.extension() {
|
||||
if extension == "metal" {
|
||||
build_kernel(&path)?;
|
||||
}
|
||||
println!("cargo:warning=output {:?}", path.file_stem());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compiles one Metal shader source into `src/compiled/<stem>.metallib`.
///
/// Two tools are run in sequence:
/// 1. `xcrun metal -c <path> -I src/ -o src/compiled/<stem>.air`
/// 2. `xcrun metallib src/compiled/<stem>.air -o src/compiled/<stem>.metallib`
///
/// # Errors
///
/// Returns an error if `path` has no UTF-8 file stem, if the output directory
/// cannot be created, if `xcrun` cannot be spawned, or if either tool exits
/// with a non-zero status (the tool's stderr is included in the message).
fn build_kernel(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
    // Validate the name before touching the filesystem or spawning anything.
    let stem = path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .ok_or_else(|| format!("cannot derive a UTF-8 kernel name from {}", path.display()))?;

    // The output directory may not exist yet; the compiler will not create it.
    std::fs::create_dir_all("src/compiled")?;

    let air = format!("src/compiled/{stem}.air");
    let metallib = format!("src/compiled/{stem}.metallib");

    // The source path is passed as an OS string, so it needs no UTF-8 check.
    run_xcrun(
        Command::new("xcrun")
            .arg("metal")
            .arg("-c")
            .arg(path)
            .args(["-I", "src/", "-o", air.as_str()]),
    )?;
    run_xcrun(Command::new("xcrun").args(["metallib", air.as_str(), "-o", metallib.as_str()]))?;
    Ok(())
}

/// Runs `command` to completion and turns a non-zero exit status into an
/// error that carries the command line and the captured stderr.
fn run_xcrun(command: &mut Command) -> Result<(), Box<dyn std::error::Error>> {
    let output = command.output()?;
    if output.status.success() {
        return Ok(());
    }
    Err(format!(
        "{command:?} failed with {}:\n{}",
        output.status,
        String::from_utf8_lossy(&output.stderr)
    )
    .into())
}
|
@ -73,6 +73,7 @@ kernel void FN_NAME_STRIDED( \
|
||||
} \
|
||||
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
|
||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
||||
@ -95,4 +96,4 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
||||
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
||||
#endif
|
||||
|
@ -1,22 +1,22 @@
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
|
||||
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib");
|
||||
const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib");
|
||||
const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");
|
||||
const BINARY: &[u8] = include_bytes!("compiled/binary.metallib");
|
||||
const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib");
|
||||
const CAST: &[u8] = include_bytes!("compiled/cast.metallib");
|
||||
const CONV: &[u8] = include_bytes!("compiled/conv.metallib");
|
||||
const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib");
|
||||
const RANDOM: &[u8] = include_bytes!("compiled/random.metallib");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||
@ -235,7 +235,7 @@ impl Kernels {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_library_source(&self, source: Source) -> &'static str {
|
||||
fn get_library_source(&self, source: Source) -> &'static [u8] {
|
||||
match source {
|
||||
Source::Affine => AFFINE,
|
||||
Source::Unary => UNARY,
|
||||
@ -247,7 +247,7 @@ impl Kernels {
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
Source::Mfa => MFA,
|
||||
}
|
||||
}
|
||||
|
||||
@ -262,22 +262,12 @@ impl Kernels {
|
||||
if let Some(lib) = libraries.get(&source) {
|
||||
Ok(lib.clone())
|
||||
} else {
|
||||
let lib = match source {
|
||||
Source::Mfa => {
|
||||
let source_data = MFA;
|
||||
device.new_library_with_data(source_data).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
source => {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||
}
|
||||
};
|
||||
let source_data = self.get_library_source(source);
|
||||
let lib = device.new_library_with_data(source_data).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||
))
|
||||
})?;
|
||||
libraries.insert(source, lib.clone());
|
||||
Ok(lib)
|
||||
}
|
||||
|
Reference in New Issue
Block a user