From 2b93dffe64d26829224f0f31e81f6c50c0e1e733 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 18 Apr 2024 22:34:29 +0200
Subject: [PATCH] Use faster rotary embeddings for llama like models. (#2087)

---
 candle-transformers/src/models/llama.rs | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 97a40d37..945c0e17 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -116,7 +116,6 @@ impl Cache {
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
-        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
         let cos = idx_theta.cos()?.to_dtype(dtype)?;
         let sin = idx_theta.sin()?.to_dtype(dtype)?;
         Ok(Self {
@@ -176,16 +175,10 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
-        let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
+        let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
         let cos = cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = cache.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
-        let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
-        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
-        Ok(rope)
+        candle_nn::rotary_emb::rope(x, &cos, &sin)
     }
 
     fn forward(
@@ -203,10 +196,12 @@ impl CausalSelfAttention {
 
         let q = q
             .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let k = k
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let mut v = v
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
             .transpose(1, 2)?;
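
Note (not part of the patch): a minimal sketch of the rope call this change switches to. It assumes the candle-core and candle-nn crates and uses illustrative shapes and constants. The key points are that the cos/sin tables now cover only half the head dimension (hence dropping the Tensor::cat duplication in Cache), and that candle_nn::rotary_emb::rope takes a contiguous (batch, heads, seq_len, head_dim) input, which appears to be why .contiguous() is added after the transposes in forward.

// Sketch only: illustrative shapes, not taken from the patch.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, t, d) = (1usize, 2usize, 5usize, 8usize);

    // Build rotary tables the way Cache does after this patch: an outer
    // product of positions with inverse frequencies over half the head dim,
    // giving cos/sin of shape (t, d / 2) -- no duplication via Tensor::cat.
    let inv_freq: Vec<f32> = (0..d / 2)
        .map(|i| 1f32 / 10_000f32.powf(2.0 * i as f32 / d as f32))
        .collect();
    let inv_freq = Tensor::new(inv_freq.as_slice(), &dev)?;
    let positions = Tensor::arange(0u32, t as u32, &dev)?.to_dtype(DType::F32)?;
    let idx_theta = positions.reshape((t, 1))?.matmul(&inv_freq.reshape((1, d / 2))?)?;
    let cos = idx_theta.cos()?;
    let sin = idx_theta.sin()?;

    // Apply the fused kernel to a contiguous (batch, heads, seq_len, head_dim) tensor.
    let q = Tensor::randn(0f32, 1f32, (b, h, t, d), &dev)?;
    let q_rot = candle_nn::rotary_emb::rope(&q, &cos, &sin)?;
    println!("{:?}", q_rot.dims()); // [1, 2, 5, 8]
    Ok(())
}

Compared with the removed code path, this replaces the manual rotate-half (narrow, neg, cat, broadcast_mul) with a single fused call, which is where the speedup comes from.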