mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Again set a few extra params in flash-attn. (#245)
* Again set a few extra params. * Use the appropriate kernel sizes. * Add all the kernel sizes. * Parallel compiling. * Reduce the amount of parallelism. * Add the missing kernel. * Fix a typo. * Remove bf16 support for now.
This commit is contained in:
@ -118,14 +118,14 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
|
||||
/* k_batch_stride */ k_stride[0] as u32,
|
||||
/* v_batch_stride */ v_stride[0] as u32,
|
||||
/* o_batch_stride */ o_stride[0] as u32,
|
||||
/* q_row_stride */ q_stride[q_rank - 3] as u32,
|
||||
/* k_row_stride */ k_stride[k_rank - 3] as u32,
|
||||
/* v_row_stride */ v_stride[v_rank - 3] as u32,
|
||||
/* o_row_stride */ o_stride[o_rank - 3] as u32,
|
||||
/* q_head_stride */ q_stride[q_rank - 2] as u32,
|
||||
/* k_head_stride */ k_stride[k_rank - 2] as u32,
|
||||
/* v_head_stride */ v_stride[v_rank - 2] as u32,
|
||||
/* o_head_stride */ o_stride[o_rank - 2] as u32,
|
||||
/* q_row_stride */ q_stride[q_rank - 3] as u32,
|
||||
/* k_row_stride */ k_stride[k_rank - 3] as u32,
|
||||
/* v_row_stride */ v_stride[v_rank - 3] as u32,
|
||||
/* o_row_stride */ o_stride[o_rank - 3] as u32,
|
||||
/* q_head_stride */ q_stride[q_rank - 2] as u32,
|
||||
/* k_head_stride */ k_stride[k_rank - 2] as u32,
|
||||
/* v_head_stride */ v_stride[v_rank - 2] as u32,
|
||||
/* o_head_stride */ o_stride[o_rank - 2] as u32,
|
||||
/* b */ b_sz as u32,
|
||||
/* h */ num_heads as u32,
|
||||
/* h_k */ num_heads_k as u32,
|
||||
|
Reference in New Issue
Block a user