mirror of https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Fix some shape errors.
@@ -176,10 +176,11 @@ impl RmsNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let last_dim = x.dims().last().unwrap();
-        let norm_x = ((x * x)?.sum(&[x.rank() - 1])? / *last_dim as f64)?;
+        let (seq_len, hidden_size) = x.shape().r2()?;
+        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let scale = self.scale.reshape(&[1, 1, self.size])?;
+        let scale = self.scale.broadcast_as((seq_len, self.size))?;
         Ok((scale * x_normed)?)
     }
 }
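The hunk above rewrites RmsNorm::forward for 2-D (seq_len, hidden_size) inputs rather than batched 3-D ones: mean of squares over the hidden dimension, divide by sqrt(mean + 1e-5), then an elementwise scale broadcast across rows. A minimal sketch of the same math in plain Rust (no candle types; `rms_norm` is a hypothetical helper for illustration):

```rust
// RMS-norm over a (seq_len, hidden_size) matrix stored row-major.
// The 1e-5 epsilon matches the diff; `scale` has length hidden_size
// and is broadcast across rows, like `broadcast_as((seq_len, size))`.
fn rms_norm(x: &[f32], scale: &[f32], seq_len: usize, hidden_size: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; seq_len * hidden_size];
    for t in 0..seq_len {
        let row = &x[t * hidden_size..(t + 1) * hidden_size];
        // Mean of squares over the hidden dim (the `sum(&[1]) / hidden_size` step).
        let norm_x = row.iter().map(|v| v * v).sum::<f32>() / hidden_size as f32;
        let inv = 1.0 / (norm_x + 1e-5).sqrt();
        for (i, v) in row.iter().enumerate() {
            out[t * hidden_size + i] = scale[i] * v * inv;
        }
    }
    out
}
```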
@@ -266,16 +267,16 @@ impl CausalSelfAttention {
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (b, t, c) = x.shape().r3()?;
+        let (t, c) = x.shape().r2()?;
         let qkv = self.c_attn.forward(x)?;
         let n_embd = self.n_embd;
-        let q = qkv.narrow(2, 0, n_embd)?;
-        let k = qkv.narrow(2, n_embd, 2 * n_embd)?;
-        let v = qkv.narrow(2, 2 * n_embd, 3 * n_embd)?;
-        let target_dim = [b, t, self.n_head, c / self.n_head];
-        let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
-        let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
-        let v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let q = qkv.narrow(1, 0, n_embd)?;
+        let k = qkv.narrow(1, n_embd, n_embd)?;
+        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
+        let target_dim = [t, self.n_head, c / self.n_head];
+        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let k = self.apply_rotary_emb(&k, freqs_cis)?;
         let k_shape = k.shape();
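Besides dropping the batch dimension (so the narrowed dim shifts from 2 to 1 and the head transpose from (1, 2) to (0, 1)), the new `narrow` calls also fix the slice extents: assuming candle's `narrow(dim, start, len)` semantics, the last argument is a length, and the old calls passed what look like end offsets (2 * n_embd, 3 * n_embd) instead. A small sketch of the corrected q/k/v split on one row of the fused projection (plain slices; `split_qkv` is a hypothetical helper):

```rust
// Split one (3 * n_embd)-wide row of the fused qkv projection into
// three contiguous n_embd-wide slices, mirroring the fixed narrows.
fn split_qkv(qkv_row: &[f32], n_embd: usize) -> (&[f32], &[f32], &[f32]) {
    assert_eq!(qkv_row.len(), 3 * n_embd);
    let q = &qkv_row[0..n_embd];              // narrow(1, 0, n_embd)
    let k = &qkv_row[n_embd..2 * n_embd];     // narrow(1, n_embd, n_embd)
    let v = &qkv_row[2 * n_embd..3 * n_embd]; // narrow(1, 2 * n_embd, n_embd)
    (q, k, v)
}
```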
@@ -288,7 +289,7 @@ impl CausalSelfAttention {
             .reshape(&[1, 1, t, t])?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let y = att.softmax(att.rank() - 1)?.matmul(&v)?;
-        let y = y.transpose(1, 2)?.reshape(&[b, t, c])?;
+        let y = y.transpose(1, 2)?.reshape(&[t, c])?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
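The unchanged mask/softmax lines above are standard causal attention: scores at positions j > i are filled with -inf so the softmax assigns them zero weight. A minimal sketch of that step on a single head's (t, t) score matrix (plain Rust, hypothetical helper, not candle's API):

```rust
// Apply a causal mask to a row-major (t, t) score matrix in place,
// then softmax each row. exp(-inf) == 0.0, so masked positions get
// zero attention weight, exactly as masked_fill + softmax above.
fn causal_softmax(att: &mut [f32], t: usize) {
    for i in 0..t {
        let row = &mut att[i * t..(i + 1) * t];
        // masked_fill: query i may not attend to positions j > i.
        for j in (i + 1)..t {
            row[j] = f32::NEG_INFINITY;
        }
        // Numerically stable softmax over the row.
        let max = row[..=i].iter().cloned().fold(f32::MIN, f32::max);
        let mut sum = 0.0f32;
        for v in row.iter_mut() {
            *v = (*v - max).exp();
            sum += *v;
        }
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}
```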