Fix the flash-attention function names. (#282)

This commit is contained in:
Laurent Mazare
2023-07-31 10:04:39 +01:00
committed by GitHub
parent 0ace420e66
commit 67834119fc

View File

@@ -17,7 +17,7 @@ fn round_multiple(x: usize, m: usize) -> usize {
impl candle::CustomOp3 for FlashAttn {
fn name(&self) -> &'static str {
- "flash-hdim32-sm80"
+ "flash-attn"
}
fn cpu_fwd(
@@ -192,7 +192,7 @@ struct FlashAttnVarLen {
impl candle::CustomOp3 for FlashAttnVarLen {
fn name(&self) -> &'static str {
- "flash-hdim32-sm80"
+ "flash-attn-varlen"
}
fn cpu_fwd(