mirror of https://github.com/huggingface/candle.git
Add back the bf16 flash-attn kernels. (#730)

candle-flash-attn/build.rs

@@ -6,7 +6,7 @@ use rayon::prelude::*;
 use std::path::PathBuf;
 use std::str::FromStr;
 
-const KERNEL_FILES: [&str; 9] = [
+const KERNEL_FILES: [&str; 17] = [
     "flash_api.cu",
     "flash_fwd_hdim128_fp16_sm80.cu",
     "flash_fwd_hdim160_fp16_sm80.cu",
@@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
     "flash_fwd_hdim32_fp16_sm80.cu",
     "flash_fwd_hdim64_fp16_sm80.cu",
     "flash_fwd_hdim96_fp16_sm80.cu",
-    // "flash_fwd_hdim128_bf16_sm80.cu",
-    // "flash_fwd_hdim160_bf16_sm80.cu",
-    // "flash_fwd_hdim192_bf16_sm80.cu",
-    // "flash_fwd_hdim224_bf16_sm80.cu",
-    // "flash_fwd_hdim256_bf16_sm80.cu",
-    // "flash_fwd_hdim32_bf16_sm80.cu",
-    // "flash_fwd_hdim64_bf16_sm80.cu",
-    // "flash_fwd_hdim96_bf16_sm80.cu",
+    "flash_fwd_hdim128_bf16_sm80.cu",
+    "flash_fwd_hdim160_bf16_sm80.cu",
+    "flash_fwd_hdim192_bf16_sm80.cu",
+    "flash_fwd_hdim224_bf16_sm80.cu",
+    "flash_fwd_hdim256_bf16_sm80.cu",
+    "flash_fwd_hdim32_bf16_sm80.cu",
+    "flash_fwd_hdim64_bf16_sm80.cu",
+    "flash_fwd_hdim96_bf16_sm80.cu",
 ];
 
 fn main() -> Result<()> {
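
Each entry in KERNEL_FILES is a separate CUDA translation unit that the build script compiles, so re-enabling the eight bf16 sources roughly doubles the number of kernels built. A minimal sketch of that role, not the crate's actual build.rs (the real flags, caching, and linking into a library are omitted, and compile_kernels is a hypothetical helper):

use rayon::prelude::*;
use std::path::{Path, PathBuf};
use std::process::Command;

// Compile every kernel source in parallel, producing one object file per .cu entry.
fn compile_kernels(kernel_files: &[&str], out_dir: &Path) {
    kernel_files.par_iter().for_each(|&f| {
        let obj = out_dir.join(PathBuf::from(f).with_extension("o"));
        let status = Command::new("nvcc")
            .args(["-O3", "-std=c++17", "-c"])
            .arg(format!("kernels/{f}"))
            .arg("-o")
            .arg(&obj)
            .status()
            .expect("failed to spawn nvcc");
        assert!(status.success(), "nvcc failed for {f}");
    });
}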

candle-flash-attn/kernels/flash_api.cu

@@ -1,20 +1,19 @@
 #include "flash_fwd_launch_template.h"
 
-// TODO: Switch back to handling bf16.
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-    FWD_HEADDIM_SWITCH(params.d, [&] {
-        run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-    });
-}
-
 // void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-//     FP16_SWITCH(!params.is_bf16, [&] {
-//         FWD_HEADDIM_SWITCH(params.d, [&] {
-//             run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-//         });
+//     FWD_HEADDIM_SWITCH(params.d, [&] {
+//         run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
 //     });
 // }
 
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    FP16_SWITCH(!params.is_bf16, [&] {
+        FWD_HEADDIM_SWITCH(params.d, [&] {
+            run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+        });
+    });
+}
+
 extern "C" void run_mha(
     void *q_ptr,
     void *k_ptr,
@@ -52,7 +51,8 @@ extern "C" void run_mha(
     uint32_t seqlen_q_rounded,
     uint32_t seqlen_k_rounded,
 
-    int is_causal
+    int is_causal,
+    int is_bf16
 ) {
     Flash_fwd_params params;
     // Reset the parameters
@@ -102,7 +102,7 @@ extern "C" void run_mha(
     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
     params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-    params.is_bf16 = 0;
+    params.is_bf16 = is_bf16;
     params.cu_seqlens_q = cu_seqlens_q_ptr;
     params.cu_seqlens_k = cu_seqlens_k_ptr;
     params.p_ptr = nullptr; // used for `return_softmax`.
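
In the restored dispatch, FP16_SWITCH picks the element type from params.is_bf16 (cutlass::half_t when the flag is clear, cutlass::bfloat16_t when it is set, following the upstream flash-attention switch macros), and FWD_HEADDIM_SWITCH then selects the kernel instantiation for the head dimension. One kernel is compiled per (dtype, head-dim bucket) pair, which is why KERNEL_FILES grows from 9 to 17 entries. A rough Rust illustration of the head-dimension bucketing, with headdim_bucket as a hypothetical helper rather than anything in the crate:

// Pick the smallest supported head-dimension bucket that fits `head_dim`,
// mirroring the cascade of `params.d <= N` checks behind FWD_HEADDIM_SWITCH.
fn headdim_bucket(head_dim: usize) -> Option<usize> {
    [32, 64, 96, 128, 160, 192, 224, 256]
        .into_iter()
        .find(|&k| head_dim <= k)
}

fn main() {
    assert_eq!(headdim_bucket(80), Some(96)); // head_dim 80 runs the hdim96 kernel
    assert_eq!(headdim_bucket(128), Some(128));
    assert_eq!(headdim_bucket(300), None); // larger head dims are unsupported
}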

candle-flash-attn/src/ffi.rs

@@ -38,6 +38,7 @@ extern "C" {
         seqlen_k_rounded: u32,
 
         is_causal: c_int,
+        is_bf16: c_int,
     );
 
 }

candle-flash-attn/src/lib.rs

@@ -146,6 +146,7 @@ impl candle::CustomOp3 for FlashAttn {
            /* seqlen_q_rounded */ seqlen_q_rounded as u32,
            /* seqlen_k_rounded */ seqlen_k_rounded as u32,
            /* is_causal */ causal,
+           /* is_bf16 */ 0,
        )
    }
 
@@ -354,6 +355,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
            /* seqlen_q_rounded */ seqlen_q_rounded as u32,
            /* seqlen_k_rounded */ seqlen_k_rounded as u32,
            /* is_causal */ causal,
+           /* is_bf16 */ 0,
        )
    }
 
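
The bf16 kernels are now compiled and run_mha accepts the flag, but both Rust call sites above still pass 0, so the wrappers keep requesting the f16 path. A minimal sketch of how the flag could later be derived from the input dtype, assuming the crate's candle dependency exposes DType, Result, and bail! the way candle-core does; is_bf16_flag is a hypothetical helper, not part of this commit:

use candle::DType;
use core::ffi::c_int;

// Map the attention inputs' dtype to the is_bf16 flag expected by run_mha.
fn is_bf16_flag(dtype: DType) -> candle::Result<c_int> {
    match dtype {
        DType::BF16 => Ok(1),
        DType::F16 => Ok(0),
        dt => candle::bail!("flash-attn expects f16 or bf16 inputs, got {dt:?}"),
    }
}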