Wasm llama2 tweaks (#309)

* Clean-up the llama2.c wasm example.

* Use a proper tokenizer.
Laurent Mazare
2023-08-02 15:49:43 +01:00
committed by GitHub
parent 4f17290ce0
commit 186c308d51
6 changed files with 14 additions and 52 deletions


@@ -106,14 +106,15 @@ struct CausalSelfAttention {
     n_key_value_head: usize,
     head_dim: usize,
     cache: Cache,
-    max_seq_len: usize,
 }
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
+        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
+        let cos = cos.unsqueeze(1)?;
+        let sin = sin.unsqueeze(1)?;
         let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
@@ -196,7 +197,6 @@ impl CausalSelfAttention {
             n_key_value_head: cfg.n_kv_heads,
             head_dim: cfg.dim / cfg.n_heads,
             cache: cache.clone(),
-            max_seq_len: cfg.seq_len,
         })
     }
 }
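
The diff above swaps `narrow(0, index_pos, seq_len)` for range indexing via `Tensor::i`, which selects the same `seq_len` rows of the precomputed cos/sin cache starting at `index_pos`, and drops the now-unused `max_seq_len` field. The following standalone sketch (not part of this commit) illustrates the equivalence of the two indexing styles; it assumes the crate is imported as `candle_core` (inside the candle workspace it is usually just `candle`) and uses a small stand-in tensor in place of the rotary cache.

// Minimal sketch of narrow vs. range indexing equivalence in candle.
// Assumption: crate imported as `candle_core`; the cache shape here is illustrative.
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for the precomputed rotary cache, shape (seq_len, head_dim / 2).
    let cache = Tensor::arange(0f32, 12., &dev)?.reshape((6, 2))?;

    let index_pos = 2;
    let seq_len = 3;

    // Old style: narrow along dim 0, starting at `index_pos`, taking `seq_len` rows.
    let a = cache.narrow(0, index_pos, seq_len)?;
    // New style: range indexing through the `IndexOp` trait, selecting the same rows.
    let b = cache.i(index_pos..index_pos + seq_len)?;

    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
    Ok(())
}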