Cleanup: fixed a few ops, removed debugging scaffolding.

Nicolas Patry
2023-11-10 23:00:32 +01:00
parent 61a070fdd1
commit d840838e95
7 changed files with 28 additions and 55 deletions

@@ -156,7 +156,6 @@ impl CausalSelfAttention {
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
         let x0 = x.narrow(D::Minus1, 0, 1)?;
         let x1 = x.narrow(D::Minus1, 1, 1)?;
-        todo!("X {x1}");
         let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
         let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
         let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
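
For context, this hunk is the rotary position embedding (RoPE) step: the head dimension is split into pairs and each pair is rotated by a position-dependent angle (dst0 = x0*cos - x1*sin, dst1 = x0*sin + x1*cos). Below is a minimal, dependency-free Rust sketch of that per-pair rotation; rope_rotate, pos, and theta are illustrative names, and the real code vectorizes the same arithmetic over (batch, seq, heads) with candle Tensors as shown above.

// Minimal RoPE sketch (illustrative helper, not part of the candle codebase):
// rotate consecutive pairs of a single head vector by a position-dependent
// angle, mirroring the dst0/dst1 arithmetic in the hunk above element-wise.
fn rope_rotate(x: &[f32], pos: usize, theta: f32) -> Vec<f32> {
    assert!(x.len() % 2 == 0, "head dimension must be even");
    let dim = x.len();
    let mut out = vec![0.0f32; dim];
    for i in 0..dim / 2 {
        // Standard RoPE frequency for pair i.
        let freq = 1.0 / theta.powf(2.0 * i as f32 / dim as f32);
        let (sin, cos) = (pos as f32 * freq).sin_cos();
        let (x0, x1) = (x[2 * i], x[2 * i + 1]);
        out[2 * i] = x0 * cos - x1 * sin; // dst0
        out[2 * i + 1] = x0 * sin + x1 * cos; // dst1
    }
    out
}

fn main() {
    // Rotate a toy 4-dim "head" at position 3 with the usual 10_000 base.
    println!("{:?}", rope_rotate(&[1.0, 0.0, 0.5, -0.5], 3, 10_000.0));
}
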
@@ -174,7 +173,6 @@ impl CausalSelfAttention {
         let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
         let q = self.apply_rotary_emb(&q, index_pos)?;
-        todo!("X {q}");
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
         if self.cache.use_kv_cache {
@@ -297,7 +295,6 @@ impl Block {
         let residual = x;
         let x = self.rms_1.forward(x)?;
         let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
-        todo!("---X {}", x);
         let residual = &x;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
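
For context, this hunk shows the pre-norm residual wiring of the block: the input is normalized, passed through attention, and added back to the residual, then the same pattern repeats with the MLP. A schematic, dependency-free Rust sketch of that wiring follows; block_forward and the closures are illustrative stand-ins for the candle RmsNorm/attention/MLP modules, not the library API.

// Schematic pre-norm residual block (illustrative names, not candle API):
// x = x + attn(norm1(x)); then x = x + mlp(norm2(x)), as in Block::forward above.
fn block_forward(
    x: &[f32],
    norm1: impl Fn(&[f32]) -> Vec<f32>,
    norm2: impl Fn(&[f32]) -> Vec<f32>,
    attn: impl Fn(&[f32]) -> Vec<f32>,
    mlp: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    // Attention branch plus residual.
    let h: Vec<f32> = attn(norm1(x).as_slice())
        .iter()
        .zip(x)
        .map(|(a, r)| a + r)
        .collect();
    // MLP branch plus residual.
    mlp(norm2(h.as_slice()).as_slice())
        .iter()
        .zip(&h)
        .map(|(m, r)| m + r)
        .collect()
}

fn main() {
    // Identity stand-ins just to show the call shape and the residual sums.
    let id = |v: &[f32]| v.to_vec();
    println!("{:?}", block_forward(&[1.0, 2.0, 3.0], id, id, id, id));
}
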
@@ -330,7 +327,6 @@ impl Llama {
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, _seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
-        //println!("Embeddings {}", self.wte.embeddings());
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, index_pos, block_idx)?;
         }