Mirror of https://github.com/huggingface/candle.git
Again set a few extra params in flash-attn. (#245)
* Again set a few extra params.
* Use the appropriate kernel sizes.
* Add all the kernel sizes.
* Parallel compiling.
* Reduce the amount of parallelism.
* Add the missing kernel.
* Fix a typo.
* Remove bf16 support for now.
@@ -220,8 +220,12 @@ impl CausalSelfAttention {
         let v = self.repeat_kv(v)?;
         let y = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = q.transpose(1, 2)?;
+            let k = k.transpose(1, 2)?;
+            let v = v.transpose(1, 2)?;
             let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-            flash_attn(softmax_scale, &q, &k, &v)?
+            flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
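For context, a minimal sketch of the pattern the added lines implement: the llama example keeps q/k/v laid out as (b_sz, n_heads, seq_len, head_dim), while the flash-attn kernel wants (b_sz, seq_len, n_heads, head_dim), so the inputs are transposed on the way in and the result is transposed back so both branches of the `if` return the same layout. The `flash_attention` helper name below is made up for illustration; the assumption that the example's feature-gated `flash_attn` wrapper forwards to `candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)` with causal masking is not shown in this diff, and the crate paths follow the in-repo examples (externally the crates are `candle-core` and `candle-flash-attn`).

```rust
use candle::{Result, Tensor};

/// Illustrative only: apply flash attention to tensors laid out as
/// (b_sz, n_heads, seq_len, head_dim), the layout used elsewhere in the
/// llama example.
fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize) -> Result<Tensor> {
    // flash-attn expects (b_sz, seq_len, n_heads, head_dim): swap dims 1 and 2.
    let q = q.transpose(1, 2)?;
    let k = k.transpose(1, 2)?;
    let v = v.transpose(1, 2)?;
    // Standard 1/sqrt(head_dim) attention scaling, matching the diff above.
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    // Assumption: the example's `flash_attn` helper ultimately calls
    // candle_flash_attn::flash_attn with causal masking enabled.
    let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, true)?;
    // Transpose back so callers see (b_sz, n_heads, seq_len, head_dim) again,
    // the same layout the non-flash branch produces.
    y.transpose(1, 2)
}
```

The `else` branch visible at the end of the hunk starts by up-casting to `DType::F32`, presumably so the unfused attention path computes its softmax in f32 even when the model weights are f16.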