Remove single function.
@@ -30,16 +30,6 @@ struct AllReduce {
     comm: Rc<Comm>,
 }
 
-fn flash_attn(
-    q: &Tensor,
-    k: &Tensor,
-    v: &Tensor,
-    softmax_scale: f32,
-    causal: bool,
-) -> Result<Tensor> {
-    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
-}
-
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
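The removed helper only forwarded its arguments to candle_flash_attn::flash_attn, so call sites can invoke the crate function directly, as the second hunk below does. A minimal sketch of the direct call, assuming the candle-flash-attn crate is a dependency and q, k, v are Tensors already in the (b_sz, seq_len, n_head, head_dim) layout the flash kernel expects:

    // Same signature as the removed wrapper: &Tensor inputs, a softmax
    // scale, and a causal flag; returns Result<Tensor>.
    let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal = */ seq_len > 1)?;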
@@ -306,7 +296,8 @@ impl CausalSelfAttention {
         let k = k.transpose(1, 2)?;
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-        let y = flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
+        let y =
+            candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.o_proj.forward(&y)?;
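For reference, the kernel called above computes ordinary scaled dot-product attention, softmax(q k^T * softmax_scale) v, with causal masking enabled whenever seq_len > 1. A rough equivalent written with plain candle tensor ops, as an illustrative sketch only (hypothetical helper, causal mask omitted; q, k, v assumed to be (b_sz, n_head, seq_len, head_dim), i.e. the layout before the transposes in the hunk above):

    use candle_core::{Result, Tensor};

    // softmax(q k^T * scale) v computed naively; flash-attn fuses these steps
    // and avoids materializing the full (seq_len, seq_len) score matrix.
    fn naive_attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
        // (b, h, s, d) x (b, h, d, s) -> (b, h, s, s)
        let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
        let att = candle_nn::ops::softmax_last_dim(&att)?;
        // contiguous() because, per the comment in the hunk above, matmul
        // doesn't support a strided v for now.
        att.matmul(&v.contiguous()?)
    }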