Add back the bf16 flash-attn kernels. (#730)

This commit is contained in:
Laurent Mazare
2023-09-04 08:50:52 +02:00
committed by GitHub
parent 20512ba408
commit d0cdea95a5
4 changed files with 25 additions and 22 deletions

View File

@ -6,7 +6,7 @@ use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;
const KERNEL_FILES: [&str; 9] = [
const KERNEL_FILES: [&str; 17] = [
"flash_api.cu",
"flash_fwd_hdim128_fp16_sm80.cu",
"flash_fwd_hdim160_fp16_sm80.cu",
@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
"flash_fwd_hdim32_fp16_sm80.cu",
"flash_fwd_hdim64_fp16_sm80.cu",
"flash_fwd_hdim96_fp16_sm80.cu",
// "flash_fwd_hdim128_bf16_sm80.cu",
// "flash_fwd_hdim160_bf16_sm80.cu",
// "flash_fwd_hdim192_bf16_sm80.cu",
// "flash_fwd_hdim224_bf16_sm80.cu",
// "flash_fwd_hdim256_bf16_sm80.cu",
// "flash_fwd_hdim32_bf16_sm80.cu",
// "flash_fwd_hdim64_bf16_sm80.cu",
// "flash_fwd_hdim96_bf16_sm80.cu",
"flash_fwd_hdim128_bf16_sm80.cu",
"flash_fwd_hdim160_bf16_sm80.cu",
"flash_fwd_hdim192_bf16_sm80.cu",
"flash_fwd_hdim224_bf16_sm80.cu",
"flash_fwd_hdim256_bf16_sm80.cu",
"flash_fwd_hdim32_bf16_sm80.cu",
"flash_fwd_hdim64_bf16_sm80.cu",
"flash_fwd_hdim96_bf16_sm80.cu",
];
fn main() -> Result<()> {