mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Compare commits
1 Commits
0.9.1
...
precompile
Author | SHA1 | Date | |
---|---|---|---|
5ac3302fac |
@ -588,6 +588,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||||
|
(DType::U32, DType::F16) => "cast_u32_f16",
|
||||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||||
|
|
||||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||||
|
2
candle-metal-kernels/.gitignore
vendored
Normal file
2
candle-metal-kernels/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
src/compiled/
|
||||||
|
|
45
candle-metal-kernels/build.rs
Normal file
45
candle-metal-kernels/build.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use std::path::Path;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let files: std::fs::ReadDir = std::fs::read_dir("src/").unwrap();
|
||||||
|
for file in files {
|
||||||
|
let file = file?;
|
||||||
|
let path = file.path();
|
||||||
|
if let Some(extension) = path.extension() {
|
||||||
|
if extension == "metal" {
|
||||||
|
build_kernel(&path)?;
|
||||||
|
}
|
||||||
|
println!("cargo:warning=output {:?}", path.file_stem());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compiles a single Metal shader source into a `.metallib` under
/// `src/compiled/`.
///
/// Two `xcrun` invocations are performed:
/// 1. `xcrun metal -c <path> -I src/ -o src/compiled/<stem>.air`
/// 2. `xcrun metallib src/compiled/<stem>.air -o src/compiled/<stem>.metallib`
///
/// # Errors
///
/// Returns an error if the file name has no UTF-8 stem, if the output
/// directory cannot be created, if `xcrun` cannot be spawned, or if either
/// `xcrun` invocation exits with a non-zero status (its stderr is included in
/// the error message).
fn build_kernel(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
    let stem = path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .ok_or_else(|| format!("expected a UTF-8 kernel file name, got {}", path.display()))?;

    // The output directory is git-ignored, so it does not exist on a fresh
    // checkout; create it before asking the compiler to write into it.
    let out_dir = Path::new("src/compiled");
    std::fs::create_dir_all(out_dir)?;
    let air = out_dir.join(format!("{stem}.air"));
    let metallib = out_dir.join(format!("{stem}.metallib"));

    // Step 1: shader source -> intermediate representation (.air).
    run_xcrun(
        Command::new("xcrun")
            .arg("metal")
            .arg("-c")
            .arg(path)
            .args(["-I", "src/"])
            .arg("-o")
            .arg(&air),
    )?;

    // Step 2: intermediate representation -> loadable library (.metallib).
    run_xcrun(
        Command::new("xcrun")
            .arg("metallib")
            .arg(&air)
            .arg("-o")
            .arg(&metallib),
    )?;

    Ok(())
}

/// Runs `cmd` to completion and turns a non-zero exit status into an error
/// carrying the command line and the captured stderr.
fn run_xcrun(cmd: &mut Command) -> Result<(), Box<dyn std::error::Error>> {
    // `output()` only fails when the process cannot be spawned; a compiler
    // error is reported through the exit status, which must be checked too.
    let output = cmd.output()?;
    if !output.status.success() {
        return Err(format!(
            "{cmd:?} failed with {}:\n{}",
            output.status,
            String::from_utf8_lossy(&output.stderr)
        )
        .into());
    }
    Ok(())
}
|
@ -73,6 +73,7 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
} \
|
} \
|
||||||
|
|
||||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||||
|
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
|
||||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||||
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
||||||
@ -95,4 +96,4 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
|||||||
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
||||||
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
||||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||||
#endif
|
#endif
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
|
||||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib");
|
||||||
const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib");
|
||||||
const UNARY: &str = include_str!("unary.metal");
|
const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");
|
||||||
const BINARY: &str = include_str!("binary.metal");
|
const BINARY: &[u8] = include_bytes!("compiled/binary.metallib");
|
||||||
const TERNARY: &str = include_str!("ternary.metal");
|
const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib");
|
||||||
const CAST: &str = include_str!("cast.metal");
|
const CAST: &[u8] = include_bytes!("compiled/cast.metallib");
|
||||||
const CONV: &str = include_str!("conv.metal");
|
const CONV: &[u8] = include_bytes!("compiled/conv.metallib");
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const RANDOM: &[u8] = include_bytes!("compiled/random.metallib");
|
||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib");
|
||||||
|
|
||||||
/// Most kernels apply similarly across the tensors
|
/// Most kernels apply similarly across the tensors
|
||||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||||
@ -235,7 +235,7 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_library_source(&self, source: Source) -> &'static str {
|
fn get_library_source(&self, source: Source) -> &'static [u8] {
|
||||||
match source {
|
match source {
|
||||||
Source::Affine => AFFINE,
|
Source::Affine => AFFINE,
|
||||||
Source::Unary => UNARY,
|
Source::Unary => UNARY,
|
||||||
@ -247,7 +247,7 @@ impl Kernels {
|
|||||||
Source::Conv => CONV,
|
Source::Conv => CONV,
|
||||||
Source::Random => RANDOM,
|
Source::Random => RANDOM,
|
||||||
Source::Quantized => QUANTIZED,
|
Source::Quantized => QUANTIZED,
|
||||||
Source::Mfa => panic!("Invalid lib"),
|
Source::Mfa => MFA,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,22 +262,12 @@ impl Kernels {
|
|||||||
if let Some(lib) = libraries.get(&source) {
|
if let Some(lib) = libraries.get(&source) {
|
||||||
Ok(lib.clone())
|
Ok(lib.clone())
|
||||||
} else {
|
} else {
|
||||||
let lib = match source {
|
let source_data = self.get_library_source(source);
|
||||||
Source::Mfa => {
|
let lib = device.new_library_with_data(source_data).map_err(|e| {
|
||||||
let source_data = MFA;
|
MetalKernelError::LoadLibraryError(format!(
|
||||||
device.new_library_with_data(source_data).map_err(|e| {
|
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||||
MetalKernelError::LoadLibraryError(format!(
|
))
|
||||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
})?;
|
||||||
))
|
|
||||||
})?
|
|
||||||
}
|
|
||||||
source => {
|
|
||||||
let source_content = self.get_library_source(source);
|
|
||||||
device
|
|
||||||
.new_library_with_source(source_content, &CompileOptions::new())
|
|
||||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
libraries.insert(source, lib.clone());
|
libraries.insert(source, lib.clone());
|
||||||
Ok(lib)
|
Ok(lib)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user