Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Add back the bf16 flash-attn kernels. (#730)
@@ -6,7 +6,7 @@ use rayon::prelude::*;
 use std::path::PathBuf;
 use std::str::FromStr;
 
-const KERNEL_FILES: [&str; 9] = [
+const KERNEL_FILES: [&str; 17] = [
     "flash_api.cu",
     "flash_fwd_hdim128_fp16_sm80.cu",
     "flash_fwd_hdim160_fp16_sm80.cu",
@@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
     "flash_fwd_hdim32_fp16_sm80.cu",
     "flash_fwd_hdim64_fp16_sm80.cu",
     "flash_fwd_hdim96_fp16_sm80.cu",
-    // "flash_fwd_hdim128_bf16_sm80.cu",
-    // "flash_fwd_hdim160_bf16_sm80.cu",
-    // "flash_fwd_hdim192_bf16_sm80.cu",
-    // "flash_fwd_hdim224_bf16_sm80.cu",
-    // "flash_fwd_hdim256_bf16_sm80.cu",
-    // "flash_fwd_hdim32_bf16_sm80.cu",
-    // "flash_fwd_hdim64_bf16_sm80.cu",
-    // "flash_fwd_hdim96_bf16_sm80.cu",
+    "flash_fwd_hdim128_bf16_sm80.cu",
+    "flash_fwd_hdim160_bf16_sm80.cu",
+    "flash_fwd_hdim192_bf16_sm80.cu",
+    "flash_fwd_hdim224_bf16_sm80.cu",
+    "flash_fwd_hdim256_bf16_sm80.cu",
+    "flash_fwd_hdim32_bf16_sm80.cu",
+    "flash_fwd_hdim64_bf16_sm80.cu",
+    "flash_fwd_hdim96_bf16_sm80.cu",
 ];
 
 fn main() -> Result<()> {
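The hunk above is the kernel list in the crate's build script: the eight bf16 forward kernels are uncommented again, growing KERNEL_FILES from 9 to 17 entries. For context, a build script consumes such a list by compiling every .cu file with nvcc; the sketch below shows that shape, compiling the entries in parallel via rayon (matching the `use rayon::prelude::*` in the hunk header). It is only an illustration under assumed paths, flags, and the sm_80 target, not the actual candle-flash-attn build.rs.

    // Minimal build-script sketch (NOT the real candle-flash-attn build.rs):
    // compile each kernel in KERNEL_FILES to an object file with nvcc, in
    // parallel. Output paths, flags, and the sm_80 target are illustrative.
    use rayon::prelude::*;
    use std::path::PathBuf;
    use std::process::Command;

    const KERNEL_FILES: [&str; 2] = [
        "flash_fwd_hdim128_fp16_sm80.cu",
        "flash_fwd_hdim128_bf16_sm80.cu",
    ];

    fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let out_dir = PathBuf::from(std::env::var("OUT_DIR")?);
        KERNEL_FILES.par_iter().try_for_each(|f| {
            // One nvcc invocation per kernel source file.
            let obj = out_dir.join(f).with_extension("o");
            let status = Command::new("nvcc")
                .args(["-c", "-O3", "--gpu-architecture=sm_80"])
                .arg(format!("kernels/{f}"))
                .arg("-o")
                .arg(&obj)
                .status()?;
            if !status.success() {
                return Err(format!("nvcc failed on {f}").into());
            }
            Ok(())
        })
    }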
@@ -1,20 +1,19 @@
 #include "flash_fwd_launch_template.h"
 
-// TODO: Switch back to handling bf16.
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-    FWD_HEADDIM_SWITCH(params.d, [&] {
-        run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-    });
-}
-
 // void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-//     FP16_SWITCH(!params.is_bf16, [&] {
-//         FWD_HEADDIM_SWITCH(params.d, [&] {
-//             run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-//         });
+//     FWD_HEADDIM_SWITCH(params.d, [&] {
+//         run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
 //     });
 // }
 
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    FP16_SWITCH(!params.is_bf16, [&] {
+        FWD_HEADDIM_SWITCH(params.d, [&] {
+            run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+        });
+    });
+}
+
 extern "C" void run_mha(
     void *q_ptr,
     void *k_ptr,
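This hunk swaps the active dispatcher in flash_api.cu: the fp16-only run_mha_fwd and its TODO are dropped, the FP16_SWITCH variant becomes the live code path, and the half_t-only body is what now stays commented out for reference. FP16_SWITCH comes from the upstream flash-attention headers and, roughly, binds elem_type to cutlass::half_t or cutlass::bfloat16_t depending on the runtime flag, so one entry point covers both dtypes. A loose Rust analogue of that flag-to-monomorphised-type dispatch, with purely illustrative names:

    // Rust analogue of the FP16_SWITCH dispatch (illustrative only): a runtime
    // flag selects the element type, and the generic launcher is monomorphised
    // for that type. Uses the `half` crate for f16/bf16 stand-ins.
    use half::{bf16, f16};

    fn run_mha_fwd_generic<T: Copy + Default + 'static>(head_dim: usize) {
        // Stand-in for the real CUDA kernel launch.
        println!(
            "flash-attn fwd: head_dim={head_dim}, elem={}",
            std::any::type_name::<T>()
        );
    }

    fn run_mha_fwd(head_dim: usize, is_bf16: bool) {
        if is_bf16 {
            run_mha_fwd_generic::<bf16>(head_dim);
        } else {
            run_mha_fwd_generic::<f16>(head_dim);
        }
    }

    fn main() {
        run_mha_fwd(128, true); // selects the bf16 instantiation
    }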
@@ -52,7 +51,8 @@ extern "C" void run_mha(
     uint32_t seqlen_q_rounded,
     uint32_t seqlen_k_rounded,
 
-    int is_causal
+    int is_causal,
+    int is_bf16
 ) {
     Flash_fwd_params params;
     // Reset the parameters
@@ -102,7 +102,7 @@ extern "C" void run_mha(
     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-    params.is_bf16 = 0;
+    params.is_bf16 = is_bf16;
     params.cu_seqlens_q = cu_seqlens_q_ptr;
     params.cu_seqlens_k = cu_seqlens_k_ptr;
     params.p_ptr = nullptr; // used for `return_softmax`.
@@ -38,6 +38,7 @@ extern "C" {
         seqlen_k_rounded: u32,
 
         is_causal: c_int,
+        is_bf16: c_int,
     );
 
 }
@@ -146,6 +146,7 @@ impl candle::CustomOp3 for FlashAttn {
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_causal */ causal,
+                /* is_bf16 */ 0,
             )
         }
 
@@ -354,6 +355,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_causal */ causal,
+                /* is_bf16 */ 0,
             )
         }
 
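In both CustomOp3 hunks above the new argument is still hardcoded to /* is_bf16 */ 0, so this commit restores the kernels and the FFI plumbing while the Rust ops keep requesting the fp16 path. A natural follow-up (not part of this commit) would derive the flag from the input dtype; a minimal sketch of such a helper, using candle's DType and labeled hypothetical:

    // Hypothetical helper (not in this commit): map the input dtype to the
    // is_bf16 flag expected by the extern "C" run_mha, rejecting other dtypes.
    use candle::{DType, Result};

    fn is_bf16_flag(dtype: DType) -> Result<core::ffi::c_int> {
        match dtype {
            DType::F16 => Ok(0),
            DType::BF16 => Ok(1),
            dt => candle::bail!("flash-attn expects f16 or bf16 inputs, got {dt:?}"),
        }
    }

The call sites would then pass the computed flag instead of the constant 0.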