Proper flash-attn parameters. (#244)

* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Setup the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
Laurent Mazare
2023-07-26 10:13:40 +01:00
committed by GitHub
parent e40b150bbe
commit fa2b64d678
5 changed files with 147 additions and 12 deletions

@@ -6,16 +6,22 @@ extern "C" {
k_ptr: *const c_void,
v_ptr: *const c_void,
o_ptr: *const c_void,
softmax_lse_ptr: *const c_void,
q_batch_stride: u32,
k_batch_stride: u32,
v_batch_stride: u32,
o_batch_stride: u32,
q_row_stride: u32,
k_row_stride: u32,
v_row_stride: u32,
o_row_stride: u32,
q_head_stride: u32,
k_head_stride: u32,
v_head_stride: u32,
o_head_stride: u32,
b: u32,
h: u32,