Mirror of https://github.com/huggingface/candle.git
Add flash attention (#241)
* Add some flash-attn kernels, import the code for flash-attn v2 from Dao-AILab.
* More flash attn.
* Set up the flash attn parameters.
* Get things to compile locally.
* Move the flash attention files into a different directory.
* Build the static C library with nvcc (see the build.rs sketch below).
* Add more flash attention.
* Update the build part.
* Better caching.
* Exclude flash attention from the default workspace.
* Put flash-attn behind a feature gate (illustrated after the build sketch).
* Get the flash attn kernel to run.
* Move the flags to a more appropriate place.
* Enable flash attention in llama.
* Use flash attention in llama.
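The "static C library with nvcc" step lives in the crate's build script. Below is a minimal sketch of the shape such a build.rs can take, assuming a single hypothetical kernels/flash_api.cu source file; the real script additionally handles caching, multiple .cu files, and GPU architecture flags. Cargo links the resulting archive into the Rust crate through the rustc-link directives printed at the end.

// build.rs — a minimal sketch of the "static C library with nvcc" step.
// Paths and flags are illustrative assumptions, not the crate's actual script.
use std::path::PathBuf;
use std::process::Command;

fn main() {
    println!("cargo:rerun-if-changed=kernels/flash_api.cu");

    let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
    let obj = out_dir.join("flash_api.o");

    // Compile the CUDA source into an object file with nvcc.
    let status = Command::new("nvcc")
        .args(["-O3", "-std=c++17", "--expt-relaxed-constexpr", "-c"])
        .arg("kernels/flash_api.cu") // hypothetical source path
        .arg("-o")
        .arg(&obj)
        .status()
        .expect("failed to spawn nvcc");
    assert!(status.success(), "nvcc returned an error");

    // Archive the object file into a static library the linker can find.
    let status = Command::new("ar")
        .arg("crs")
        .arg(out_dir.join("libflashattention.a"))
        .arg(&obj)
        .status()
        .expect("failed to spawn ar");
    assert!(status.success(), "ar returned an error");

    println!("cargo:rustc-link-search=native={}", out_dir.display());
    println!("cargo:rustc-link-lib=static=flashattention");
    // The kernels also need the CUDA runtime and the C++ standard library.
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=stdc++");
}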
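The feature gate mentioned above surfaces on the Rust side as conditional compilation. A hedged illustration follows: the `flash-attn` feature name is taken from the commit description, but the function is a made-up stand-in for the actual call site.

// Sketch only: a made-up call site showing how a `flash-attn` cargo
// feature can switch between the flash kernel and a fallback path.
#[cfg(feature = "flash-attn")]
fn attention_backend() -> &'static str {
    "flash-attn CUDA kernel"
}

#[cfg(not(feature = "flash-attn"))]
fn attention_backend() -> &'static str {
    "default softmax(QK^T / sqrt(d)) V implementation"
}

fn main() {
    // Selected at compile time by `cargo build --features flash-attn`.
    println!("attention backend: {}", attention_backend());
}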
candle-flash-attn/src/ffi.rs (new file, 35 additions)
@@ -0,0 +1,35 @@
use core::ffi::{c_int, c_void};

extern "C" {
    pub(crate) fn run_mha(
        // Device pointers to the query, key, value and output buffers.
        q_ptr: *const c_void,
        k_ptr: *const c_void,
        v_ptr: *const c_void,
        o_ptr: *const c_void,

        // Strides (in elements) between consecutive batches, rows, and heads.
        q_batch_stride: u32,
        k_batch_stride: u32,
        v_batch_stride: u32,
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        q_head_stride: u32,
        k_head_stride: u32,
        v_head_stride: u32,

        // Dimensions: batch size, number of query heads, number of key/value
        // heads, head dimension and its rounded-up variant, and the softmax
        // scaling factor (typically 1/sqrt(d)).
        b: u32,
        h: u32,
        h_k: u32,
        d: u32,
        d_rounded: u32,
        softmax_scale: f32,

        // Query/key sequence lengths and their rounded-up variants.
        seqlen_q: u32,
        seqlen_k: u32,
        seqlen_q_rounded: u32,
        seqlen_k_rounded: u32,

        // Non-zero to apply a causal attention mask.
        is_causal: c_int,
    );
}
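For context, here is a hedged sketch of how a caller inside the crate might fill in these arguments for contiguous tensors laid out as (batch, seqlen, heads, head_dim). The helper name, the rounding multiples (32 and 128), and the 1/sqrt(d) scale are assumptions for illustration, not the crate's actual API; in the real crate the pointers would come from candle's CUDA storage.

// Hypothetical helper: invokes `run_mha` for contiguous tensors laid out
// as (b, seqlen, heads, d). All pointers must be valid CUDA device pointers.
#[allow(clippy::too_many_arguments)]
pub(crate) unsafe fn mha_contiguous(
    q: *const c_void,
    k: *const c_void,
    v: *const c_void,
    o: *const c_void,
    b: u32,
    h: u32,
    h_k: u32,
    d: u32,
    seqlen_q: u32,
    seqlen_k: u32,
    causal: bool,
) {
    // Round `x` up to the next multiple of `m` (multiples assumed here).
    let round = |x: u32, m: u32| x.div_ceil(m) * m;
    run_mha(
        q,
        k,
        v,
        o,
        seqlen_q * h * d,   // q_batch_stride: one batch = seqlen_q rows
        seqlen_k * h_k * d, // k_batch_stride
        seqlen_k * h_k * d, // v_batch_stride
        h * d,              // q_row_stride: one row = h heads of d elements
        h_k * d,            // k_row_stride
        h_k * d,            // v_row_stride
        d,                  // q_head_stride
        d,                  // k_head_stride
        d,                  // v_head_stride
        b,
        h,
        h_k,
        d,
        round(d, 32),
        1.0 / (d as f32).sqrt(), // standard softmax scale 1/sqrt(d)
        seqlen_q,
        seqlen_k,
        round(seqlen_q, 128),
        round(seqlen_k, 128),
        if causal { 1 } else { 0 },
    )
}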