From 60a5598c8b9e8205f48385c6e5cf2dc26bdab8cd Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 25 Jun 2023 17:56:59 +0100
Subject: [PATCH] Fix some shape errors.

---
 examples/llama/main.rs | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/examples/llama/main.rs b/examples/llama/main.rs
index 54a02079..473cdb08 100644
--- a/examples/llama/main.rs
+++ b/examples/llama/main.rs
@@ -176,10 +176,11 @@ impl RmsNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let last_dim = x.dims().last().unwrap();
-        let norm_x = ((x * x)?.sum(&[x.rank() - 1])? / *last_dim as f64)?;
+        let (seq_len, hidden_size) = x.shape().r2()?;
+        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let scale = self.scale.reshape(&[1, 1, self.size])?;
+        let scale = self.scale.broadcast_as((seq_len, self.size))?;
         Ok((scale * x_normed)?)
     }
 }
@@ -266,16 +267,16 @@ impl CausalSelfAttention {
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (b, t, c) = x.shape().r3()?;
+        let (t, c) = x.shape().r2()?;
         let qkv = self.c_attn.forward(x)?;
         let n_embd = self.n_embd;
-        let q = qkv.narrow(2, 0, n_embd)?;
-        let k = qkv.narrow(2, n_embd, 2 * n_embd)?;
-        let v = qkv.narrow(2, 2 * n_embd, 3 * n_embd)?;
-        let target_dim = [b, t, self.n_head, c / self.n_head];
-        let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
-        let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
-        let v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let q = qkv.narrow(1, 0, n_embd)?;
+        let k = qkv.narrow(1, n_embd, n_embd)?;
+        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
+        let target_dim = [t, self.n_head, c / self.n_head];
+        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let k = self.apply_rotary_emb(&k, freqs_cis)?;
         let k_shape = k.shape();
@@ -288,7 +289,7 @@ impl CausalSelfAttention {
             .reshape(&[1, 1, t, t])?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let y = att.softmax(att.rank() - 1)?.matmul(&v)?;
-        let y = y.transpose(1, 2)?.reshape(&[b, t, c])?;
+        let y = y.transpose(1, 2)?.reshape(&[t, c])?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
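
For reference, a minimal sketch of the patched RmsNorm as a standalone unit, assembled
from the first hunk above; only the tensor calls that appear in the patch are used. The
field types (scale: Tensor, size: usize), the import paths, and the Result<Tensor>
return type are assumptions based on how these names are used in the example, not taken
verbatim from the rest of the file:

    use anyhow::Result;       // assumed error type for this example
    use candle::Tensor;       // assumed import path

    struct RmsNorm {
        scale: Tensor, // learned per-channel gain of length `size` -- assumed
        size: usize,   // hidden_size -- assumed
    }

    impl RmsNorm {
        fn forward(&self, x: &Tensor) -> Result<Tensor> {
            // After this patch, x is rank 2: (seq_len, hidden_size), no batch dim.
            let (seq_len, hidden_size) = x.shape().r2()?;
            // Mean of squares over the hidden dimension, broadcast back to x's shape.
            let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
            let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
            let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
            // Broadcast the per-channel scale across the sequence dimension.
            let scale = self.scale.broadcast_as((seq_len, self.size))?;
            Ok((scale * x_normed)?)
        }
    }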