mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00

* Proper flash-attn parameters.
* Set the flash attention parameters.
* Add more validations.
* Set up the `o_` flash attn parameters.
* More flash-attn support.
* Set more flash attn parameters.
42 lines
837 B
Rust
42 lines
837 B
Rust
use core::ffi::{c_int, c_void};
|
|
|
|
// Rust-side declaration of the foreign (non-Rust) attention entry point used
// by the flash-attention integration. The compiler and linker do NOT check this
// signature against the foreign definition: parameter order and types must
// match it exactly, otherwise every call has undefined behavior.
extern "C" {
    /// Foreign attention entry point (the name suggests multi-head attention).
    ///
    /// Tensors are passed as untyped pointers. Each of `q`, `k`, `v` and `o`
    /// comes with explicit batch, row and head strides. These are followed by
    /// the problem dimensions, the softmax scale, the sequence lengths and a
    /// causal flag.
    ///
    /// # Safety
    ///
    /// - The foreign definition of `run_mha` must have exactly this parameter
    ///   order and these parameter types.
    /// - Every pointer must be valid for all accesses implied by the strides
    ///   and dimensions passed alongside it.
    /// - NOTE(review): `o_ptr` and `softmax_lse_ptr` are declared `*const`.
    ///   Going by their names, the callee presumably writes through them. If
    ///   so, callers must derive them from mutable (not shared) borrows.
    ///   TODO confirm against the foreign definition.
    pub(crate) fn run_mha(
        // Tensor data pointers. Element type and memory location are defined
        // by the foreign side and are not visible here.
        q_ptr: *const c_void,
        k_ptr: *const c_void,
        v_ptr: *const c_void,
        o_ptr: *const c_void,
        softmax_lse_ptr: *const c_void,

        // Per-tensor strides between consecutive batch entries. The units
        // (elements vs. bytes) are set by the foreign side; TODO confirm.
        q_batch_stride: u32,
        k_batch_stride: u32,
        v_batch_stride: u32,
        o_batch_stride: u32,

        // Per-tensor strides between consecutive rows.
        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        o_row_stride: u32,

        // Per-tensor strides between consecutive heads.
        q_head_stride: u32,
        k_head_stride: u32,
        v_head_stride: u32,
        o_head_stride: u32,

        // Problem sizes. By name, presumably: batch size, query heads,
        // key/value heads, head dimension and a rounded head dimension.
        // TODO confirm against the foreign definition.
        b: u32,
        h: u32,
        h_k: u32,
        d: u32,
        d_rounded: u32,
        // Softmax scaling factor (meaning inferred from the name).
        softmax_scale: f32,

        // Query/key sequence lengths and (presumably) rounded counterparts.
        seqlen_q: u32,
        seqlen_k: u32,
        seqlen_q_rounded: u32,
        seqlen_k_rounded: u32,

        // C-style integer flag; presumably non-zero enables causal masking.
        is_causal: c_int,
    );
}
|