Temporary state.

This commit is contained in:
Nicolas Patry
2023-11-10 15:35:46 +01:00
committed by Nicolas Patry
parent f710fab02e
commit d46670f7c0
14 changed files with 699 additions and 63 deletions

View File

@@ -165,6 +165,7 @@ impl CausalSelfAttention {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
todo!("X {q}");
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
@@ -295,6 +296,7 @@ impl Block {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
todo!("---X {}", x);
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
@@ -327,6 +329,7 @@ impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
//println!("Embeddings {}", self.wte.embeddings());
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
}