mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Fix the flash-attention function names. (#282)
This commit is contained in:
@ -17,7 +17,7 @@ fn round_multiple(x: usize, m: usize) -> usize {
|
||||
|
||||
impl candle::CustomOp3 for FlashAttn {
|
||||
fn name(&self) -> &'static str {
|
||||
"flash-hdim32-sm80"
|
||||
"flash-attn"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
@ -192,7 +192,7 @@ struct FlashAttnVarLen {
|
||||
|
||||
impl candle::CustomOp3 for FlashAttnVarLen {
|
||||
fn name(&self) -> &'static str {
|
||||
"flash-hdim32-sm80"
|
||||
"flash-attn-varlen"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
|
Reference in New Issue
Block a user