Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Remove single function.
@@ -30,16 +30,6 @@ struct AllReduce {
     comm: Rc<Comm>,
 }
 
-fn flash_attn(
-    q: &Tensor,
-    k: &Tensor,
-    v: &Tensor,
-    softmax_scale: f32,
-    causal: bool,
-) -> Result<Tensor> {
-    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
-}
-
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
@@ -306,7 +296,8 @@ impl CausalSelfAttention {
         let k = k.transpose(1, 2)?;
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-        let y = flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
+        let y =
+            candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.o_proj.forward(&y)?;
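
For reference, a minimal self-contained sketch of calling candle_flash_attn::flash_attn directly, as the call site above now does. The signature (q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32, causal: bool) -> Result<Tensor> is taken from the removed wrapper; everything else (the main scaffold, tensor shapes, the candle_core crate name, and the CUDA/f16 setup) is an assumption for illustration, not part of this commit.

// Sketch only: shapes, device, and dtype are illustrative assumptions.
// Inputs follow the layout used at the call site above:
// (batch, seq_len, num_heads, head_dim), half precision on a CUDA device.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let (b_sz, seq_len, n_heads, head_dim) = (1usize, 8, 4, 32);
    let q = Tensor::randn(0f32, 1., (b_sz, seq_len, n_heads, head_dim), &device)?
        .to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1., (b_sz, seq_len, n_heads, head_dim), &device)?
        .to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1., (b_sz, seq_len, n_heads, head_dim), &device)?
        .to_dtype(DType::F16)?;
    // Same scaling as the diff: 1 / sqrt(head_dim).
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    // Direct call, replacing the removed pass-through wrapper; the causal
    // flag mirrors the `seq_len > 1` condition used above.
    let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?;
    println!("{:?}", y.dims());
    Ok(())
}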