mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
More llama fixes.
This commit is contained in:
@ -251,9 +251,13 @@ impl CausalSelfAttention {
|
||||
let x = x.reshape(dims)?;
|
||||
let rank = x.rank();
|
||||
let re_x = x.narrow(rank - 1, 0, 1)?;
|
||||
let im_x = x.narrow(rank - 1, 1, 2)?;
|
||||
let re_f = freqs_cis.narrow(rank - 1, 0, 1)?;
|
||||
let im_f = freqs_cis.narrow(rank - 1, 1, 2)?;
|
||||
let im_x = x.narrow(rank - 1, 1, 1)?;
|
||||
let re_f = freqs_cis
|
||||
.narrow(rank - 1, 0, 1)?
|
||||
.broadcast_as(re_x.shape())?;
|
||||
let im_f = freqs_cis
|
||||
.narrow(rank - 1, 1, 1)?
|
||||
.broadcast_as(im_x.shape())?;
|
||||
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
|
||||
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
|
||||
let rope = Tensor::cat(&[&re, &im], rank - 1)?;
|
||||
@ -288,7 +292,9 @@ impl CausalSelfAttention {
|
||||
// TODO: .lower_triangle()?
|
||||
.reshape(&[1, 1, t, t])?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let y = att.softmax(att.rank() - 1)?.matmul(&v)?;
|
||||
let att = att.softmax(att.rank() - 1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[t, c])?;
|
||||
let y = self.c_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
@ -358,7 +364,7 @@ impl Llama {
|
||||
x = block.forward(&x, freqs_cis)?;
|
||||
}
|
||||
let x = self.ln_f.forward(&x)?;
|
||||
let x = x.narrow(1, t - 1, t)?;
|
||||
let x = x.narrow(0, t - 1, 1)?;
|
||||
let logits = self.lm_head.forward(&x)?;
|
||||
Ok(logits)
|
||||
}
|
||||
@ -377,7 +383,7 @@ fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
|
||||
let idx_theta = arange
|
||||
.reshape((arange.elem_count(), 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let shape = [1, 1, seq_len, n_elem / 2, 1];
|
||||
let shape = [1, seq_len, n_elem / 2, 1];
|
||||
let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
|
||||
let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
|
||||
let last_dim = idx_theta_cos.rank() - 1;
|
||||
|
Reference in New Issue
Block a user