mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Fix the flash-attention function names. (#282)
This commit is contained in:
@ -17,7 +17,7 @@ fn round_multiple(x: usize, m: usize) -> usize {
|
|||||||
|
|
||||||
impl candle::CustomOp3 for FlashAttn {
|
impl candle::CustomOp3 for FlashAttn {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> &'static str {
|
||||||
"flash-hdim32-sm80"
|
"flash-attn"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cpu_fwd(
|
fn cpu_fwd(
|
||||||
@ -192,7 +192,7 @@ struct FlashAttnVarLen {
|
|||||||
|
|
||||||
impl candle::CustomOp3 for FlashAttnVarLen {
|
impl candle::CustomOp3 for FlashAttnVarLen {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> &'static str {
|
||||||
"flash-hdim32-sm80"
|
"flash-attn-varlen"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cpu_fwd(
|
fn cpu_fwd(
|
||||||
|
Reference in New Issue
Block a user