Update the flash attn kernels. (#2333)

This commit is contained in:
Laurent Mazare
2024-07-15 20:37:36 +02:00
committed by GitHub
parent d74fbed334
commit 30cdd769f9
51 changed files with 2279 additions and 904 deletions

View File

@ -4,7 +4,7 @@
use anyhow::{Context, Result};
use std::path::PathBuf;
const KERNEL_FILES: [&str; 17] = [
const KERNEL_FILES: [&str; 33] = [
"kernels/flash_api.cu",
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
@ -22,6 +22,22 @@ const KERNEL_FILES: [&str; 17] = [
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
"kernels/flash_fwd_hdim128_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim192_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim224_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim256_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim32_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim64_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim96_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim128_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim160_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim192_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim224_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim256_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim32_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_causal_sm80.cu",
];
fn main() -> Result<()> {