mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add back the bf16 flash-attn kernels. (#730)
This commit is contained in:
@ -6,7 +6,7 @@ use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
|
||||
const KERNEL_FILES: [&str; 9] = [
|
||||
const KERNEL_FILES: [&str; 17] = [
|
||||
"flash_api.cu",
|
||||
"flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"flash_fwd_hdim160_fp16_sm80.cu",
|
||||
@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
|
||||
"flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"flash_fwd_hdim96_fp16_sm80.cu",
|
||||
// "flash_fwd_hdim128_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim160_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim192_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim224_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim256_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim32_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim64_bf16_sm80.cu",
|
||||
// "flash_fwd_hdim96_bf16_sm80.cu",
|
||||
"flash_fwd_hdim128_bf16_sm80.cu",
|
||||
"flash_fwd_hdim160_bf16_sm80.cu",
|
||||
"flash_fwd_hdim192_bf16_sm80.cu",
|
||||
"flash_fwd_hdim224_bf16_sm80.cu",
|
||||
"flash_fwd_hdim256_bf16_sm80.cu",
|
||||
"flash_fwd_hdim32_bf16_sm80.cu",
|
||||
"flash_fwd_hdim64_bf16_sm80.cu",
|
||||
"flash_fwd_hdim96_bf16_sm80.cu",
|
||||
];
|
||||
|
||||
fn main() -> Result<()> {
|
||||
|
Reference in New Issue
Block a user