From 97d8712ba507dbdb06c639b0c6b8857e454bb269 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 28 Jul 2023 10:26:41 +0000
Subject: [PATCH] Remove single function.

---
 .../examples/llama_multiprocess/model.rs | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 6980057f..573eae11 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -30,16 +30,6 @@ struct AllReduce {
     comm: Rc<Comm>,
 }
 
-fn flash_attn(
-    q: &Tensor,
-    k: &Tensor,
-    v: &Tensor,
-    softmax_scale: f32,
-    causal: bool,
-) -> Result<Tensor> {
-    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
-}
-
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
@@ -306,7 +296,8 @@ impl CausalSelfAttention {
         let k = k.transpose(1, 2)?;
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-        let y = flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
+        let y =
+            candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.o_proj.forward(&y)?;