Remove single function.

Nicolas Patry
2023-07-28 10:26:41 +00:00
parent 97181a77c0
commit 97d8712ba5


@@ -30,16 +30,6 @@ struct AllReduce {
     comm: Rc<Comm>,
 }
-fn flash_attn(
-    q: &Tensor,
-    k: &Tensor,
-    v: &Tensor,
-    softmax_scale: f32,
-    causal: bool,
-) -> Result<Tensor> {
-    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
-}
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
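
For context, a minimal standalone sketch of the pattern this hunk keeps: an Rc<Comm> makes AllReduce non-Sync by default, and the example opts back in with an unsafe impl even though NCCL's documentation does not guarantee thread safety. The Comm stand-in type is an assumption for the sketch, not the real NCCL binding.

use std::rc::Rc;

// Stand-in for the NCCL communicator handle; an assumption for this sketch only.
struct Comm;

struct AllReduce {
    comm: Rc<Comm>,
}

// Rc<Comm> is neither Send nor Sync, so this assertion is what allows &AllReduce
// to be shared across threads. As the source comment notes, NCCL does not promise
// thread safety, so this is acceptable only for the example.
unsafe impl Sync for AllReduce {}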
@@ -306,7 +296,8 @@ impl CausalSelfAttention {
         let k = k.transpose(1, 2)?;
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-        let y = flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
+        let y =
+            candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.o_proj.forward(&y)?;
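
For reference, a minimal sketch of the call sequence after this change, using the candle-flash-attn crate directly instead of the removed wrapper. The helper name attend and its parameters are illustrative only, and the imports assume the in-repo candle alias for candle-core; only candle_flash_attn::flash_attn and the tensor ops shown in the hunk are taken from the source.

use candle::{Result, Tensor};

// Illustrative helper (not part of the commit): q, k and v arrive as
// (b_sz, n_heads, seq_len, head_dim); flash-attn expects
// (b_sz, seq_len, n_heads, head_dim), hence the transposes around the call.
fn attend(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize, seq_len: usize) -> Result<Tensor> {
    let q = q.transpose(1, 2)?;
    let k = k.transpose(1, 2)?;
    let v = v.transpose(1, 2)?;
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    // The causal mask is only needed when more than one token is processed at once.
    let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?;
    // Back to (b_sz, n_heads, seq_len, head_dim) for the rest of the forward pass.
    y.transpose(1, 2)
}

A caller can then reshape the result to (b_sz, seq_len, n_embd) and apply the output projection, as the second hunk does.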