Add back the bf16 flash-attn kernels. (#730)

This commit is contained in:
Laurent Mazare
2023-09-04 08:50:52 +02:00
committed by GitHub
parent 20512ba408
commit d0cdea95a5
4 changed files with 25 additions and 22 deletions

View File

@ -146,6 +146,7 @@ impl candle::CustomOp3 for FlashAttn {
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
/* is_bf16 */ 0,
)
}
@ -354,6 +355,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
/* is_bf16 */ 0,
)
}