Fixed matmul (display still broken without casting back to CPU first? )

This commit is contained in:
Nicolas Patry
2023-11-10 20:09:25 +01:00
committed by Nicolas Patry
parent d46670f7c0
commit 38de52bc4b
4 changed files with 127 additions and 111 deletions

View File

@ -156,6 +156,7 @@ impl CausalSelfAttention {
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
todo!("X {x1}");
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@ -165,7 +166,6 @@ impl CausalSelfAttention {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
todo!("X {q}");
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
@ -174,6 +174,7 @@ impl CausalSelfAttention {
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?;
todo!("X {q}");
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if self.cache.use_kv_cache {