More llama fixes.

laurent
2023-06-25 18:08:41 +01:00
parent 60a5598c8b
commit 87c5aab005


@@ -251,9 +251,13 @@ impl CausalSelfAttention {
         let x = x.reshape(dims)?;
         let rank = x.rank();
         let re_x = x.narrow(rank - 1, 0, 1)?;
-        let im_x = x.narrow(rank - 1, 1, 2)?;
-        let re_f = freqs_cis.narrow(rank - 1, 0, 1)?;
-        let im_f = freqs_cis.narrow(rank - 1, 1, 2)?;
+        let im_x = x.narrow(rank - 1, 1, 1)?;
+        let re_f = freqs_cis
+            .narrow(rank - 1, 0, 1)?
+            .broadcast_as(re_x.shape())?;
+        let im_f = freqs_cis
+            .narrow(rank - 1, 1, 1)?
+            .broadcast_as(im_x.shape())?;
         let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
         let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
         let rope = Tensor::cat(&[&re, &im], rank - 1)?;
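
The hunk above fixes the rotary embedding application: the imaginary half is now narrowed to a single element, and the precomputed cos/sin tensors (re_f / im_f) are broadcast to the shape of the matching halves of x before the complex rotation. A minimal standalone sketch of the per-pair rotation, in plain Rust without the candle tensor API (the helper name is hypothetical):

// Hypothetical helper: rotate one (re, im) pair by an angle given as (cos, sin),
// mirroring `re = re_x*re_f - im_x*im_f` and `im = re_x*im_f + im_x*re_f` above.
fn rotate_pair(re_x: f32, im_x: f32, cos_t: f32, sin_t: f32) -> (f32, f32) {
    (re_x * cos_t - im_x * sin_t, re_x * sin_t + im_x * cos_t)
}
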
@@ -288,7 +292,9 @@ impl CausalSelfAttention {
             // TODO: .lower_triangle()?
             .reshape(&[1, 1, t, t])?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let y = att.softmax(att.rank() - 1)?.matmul(&v)?;
+        let att = att.softmax(att.rank() - 1)?;
+        // Convert to contiguous as matmul doesn't support strided vs for now.
+        let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[t, c])?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
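
This hunk splits the softmax from the matmul and makes v contiguous, since the matmul kernel does not yet handle a strided right-hand side. A rough sketch of the row-major indexing such a kernel assumes (plain Rust, not the actual candle implementation):

// Multiply an (m x k) row-major matrix by a (k x n) row-major matrix.
// A strided (e.g. transposed) `b` would break the `b[l * n + j]` indexing below,
// which is why `v` is made contiguous before the call above.
fn matmul_row_major(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for l in 0..k {
                acc += a[i * k + l] * b[l * n + j];
            }
            out[i * n + j] = acc;
        }
    }
    out
}
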
@@ -358,7 +364,7 @@ impl Llama {
             x = block.forward(&x, freqs_cis)?;
         }
         let x = self.ln_f.forward(&x)?;
-        let x = x.narrow(1, t - 1, t)?;
+        let x = x.narrow(0, t - 1, 1)?;
         let logits = self.lm_head.forward(&x)?;
         Ok(logits)
     }
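
The narrow(0, t - 1, 1) change keeps only the hidden state of the last of the t positions before the LM head, so logits are produced for a single token. A small sketch of the equivalent slice on a flattened buffer (hypothetical helper, assuming a row-major (t, c) layout):

// Keep only the last of `t` rows of a flattened, row-major (t, c) buffer,
// mirroring `x.narrow(0, t - 1, 1)` before `lm_head.forward`.
fn last_token_row(hidden: &[f32], t: usize, c: usize) -> &[f32] {
    &hidden[(t - 1) * c..t * c]
}
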
@@ -377,7 +383,7 @@ fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
     let idx_theta = arange
         .reshape((arange.elem_count(), 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-    let shape = [1, 1, seq_len, n_elem / 2, 1];
+    let shape = [1, seq_len, n_elem / 2, 1];
     let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
     let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
     let last_dim = idx_theta_cos.rank() - 1;
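
In precompute_freqs_cis, idx_theta is the outer product of the position indices with the per-dimension inverse frequencies, and the cos/sin tables are now reshaped to [1, seq_len, n_elem / 2, 1] (one leading dimension fewer) so they line up with the rank used by the attention code above. A rough sketch of that angle table under the usual RoPE convention (the base of 10_000.0 is an assumption; the real values come from Config):

// Angle for each (position, frequency) pair, with inv_freq_j = base^(-2j / n_elem).
// `base` is typically 10_000.0; it is taken as a parameter here for illustration only.
fn angle_table(seq_len: usize, n_elem: usize, base: f32) -> Vec<f32> {
    let half = n_elem / 2;
    let mut table = Vec::with_capacity(seq_len * half);
    for pos in 0..seq_len {
        for j in 0..half {
            let inv_freq = base.powf(-(2.0 * j as f32) / n_elem as f32);
            table.push(pos as f32 * inv_freq);
        }
    }
    table
}
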