diff --git a/examples/llama/main.rs b/examples/llama/main.rs index 473cdb08..c9089e95 100644 --- a/examples/llama/main.rs +++ b/examples/llama/main.rs @@ -251,9 +251,13 @@ impl CausalSelfAttention { let x = x.reshape(dims)?; let rank = x.rank(); let re_x = x.narrow(rank - 1, 0, 1)?; - let im_x = x.narrow(rank - 1, 1, 2)?; - let re_f = freqs_cis.narrow(rank - 1, 0, 1)?; - let im_f = freqs_cis.narrow(rank - 1, 1, 2)?; + let im_x = x.narrow(rank - 1, 1, 1)?; + let re_f = freqs_cis + .narrow(rank - 1, 0, 1)? + .broadcast_as(re_x.shape())?; + let im_f = freqs_cis + .narrow(rank - 1, 1, 1)? + .broadcast_as(im_x.shape())?; let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?; let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?; let rope = Tensor::cat(&[&re, &im], rank - 1)?; @@ -288,7 +292,9 @@ impl CausalSelfAttention { // TODO: .lower_triangle()? .reshape(&[1, 1, t, t])?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let y = att.softmax(att.rank() - 1)?.matmul(&v)?; + let att = att.softmax(att.rank() - 1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; let y = y.transpose(1, 2)?.reshape(&[t, c])?; let y = self.c_proj.forward(&y)?; Ok(y) @@ -358,7 +364,7 @@ impl Llama { x = block.forward(&x, freqs_cis)?; } let x = self.ln_f.forward(&x)?; - let x = x.narrow(1, t - 1, t)?; + let x = x.narrow(0, t - 1, 1)?; let logits = self.lm_head.forward(&x)?; Ok(logits) } @@ -377,7 +383,7 @@ fn precompute_freqs_cis(config: &Config) -> Result { let idx_theta = arange .reshape((arange.elem_count(), 1))? .matmul(&theta.reshape((1, theta.elem_count()))?)?; - let shape = [1, 1, seq_len, n_elem / 2, 1]; + let shape = [1, seq_len, n_elem / 2, 1]; let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?; let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?; let last_dim = idx_theta_cos.rank() - 1;