Again set a few extra params in flash-attn. (#245)

* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling.

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
This commit is contained in:
Laurent Mazare
2023-07-26 14:16:37 +01:00
committed by GitHub
parent fa2b64d678
commit 2ce5f12513
21 changed files with 476 additions and 116 deletions

View File

@ -118,14 +118,14 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
/* k_batch_stride */ k_stride[0] as u32,
/* v_batch_stride */ v_stride[0] as u32,
/* o_batch_stride */ o_stride[0] as u32,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
/* o_row_stride */ o_stride[o_rank - 3] as u32,
/* q_head_stride */ q_stride[q_rank - 2] as u32,
/* k_head_stride */ k_stride[k_rank - 2] as u32,
/* v_head_stride */ v_stride[v_rank - 2] as u32,
/* o_head_stride */ o_stride[o_rank - 2] as u32,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
/* o_row_stride */ o_stride[o_rank - 3] as u32,
/* q_head_stride */ q_stride[q_rank - 2] as u32,
/* k_head_stride */ k_stride[k_rank - 2] as u32,
/* v_head_stride */ v_stride[v_rank - 2] as u32,
/* o_head_stride */ o_stride[o_rank - 2] as u32,
/* b */ b_sz as u32,
/* h */ num_heads as u32,
/* h_k */ num_heads_k as u32,