mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Tmp state.
This commit is contained in:

committed by
Nicolas Patry

parent
f710fab02e
commit
d46670f7c0
@ -165,6 +165,7 @@ impl CausalSelfAttention {
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let q = self.q_proj.forward(x)?;
|
||||
todo!("X {q}");
|
||||
let k = self.k_proj.forward(x)?;
|
||||
let v = self.v_proj.forward(x)?;
|
||||
|
||||
@ -295,6 +296,7 @@ impl Block {
|
||||
let residual = x;
|
||||
let x = self.rms_1.forward(x)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||
todo!("---X {}", x);
|
||||
let residual = &x;
|
||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||
Ok(x)
|
||||
@ -327,6 +329,7 @@ impl Llama {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, _seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
//println!("Embeddings {}", self.wte.embeddings());
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx)?;
|
||||
}
|
||||
|
Reference in New Issue
Block a user