Clean-up the llama2.c wasm example.

laurent
2023-08-02 15:20:03 +01:00
parent 4f17290ce0
commit 145706f8df
2 changed files with 4 additions and 29 deletions

@@ -1,28 +1,3 @@
-#![allow(dead_code)]
-
-pub const WITH_TIMER: bool = true;
-
-struct Timer {
-    label: &'static str,
-}
-
-impl Timer {
-    fn new(label: &'static str) -> Self {
-        if WITH_TIMER {
-            web_sys::console::time_with_label(label);
-        }
-        Self { label }
-    }
-}
-
-impl Drop for Timer {
-    fn drop(&mut self) {
-        if WITH_TIMER {
-            web_sys::console::time_end_with_label(self.label)
-        }
-    }
-}
 mod app;
 mod model;
 mod worker;
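
For context, the removed Timer was an RAII guard over the browser console's timing API: constructing it calls console.time(label) via web_sys, and the Drop impl calls console.timeEnd(label) when the guard goes out of scope. A minimal sketch of how such a guard would be used (the call sites are not part of this diff, and run_forward is a hypothetical name):

    // Hypothetical call site for the Timer type removed above.
    fn run_forward() {
        // Starts console.time("forward") in the browser console.
        let _timer = Timer::new("forward");
        // ... model work being measured ...
    } // _timer dropped here, which triggers console.timeEnd("forward").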

@@ -106,14 +106,15 @@ struct CausalSelfAttention {
     n_key_value_head: usize,
     head_dim: usize,
     cache: Cache,
+    max_seq_len: usize,
 }
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
+        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
         let cos = cos.unsqueeze(1)?;
         let sin = sin.unsqueeze(1)?;
         let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
@@ -196,7 +197,6 @@ impl CausalSelfAttention {
             n_key_value_head: cfg.n_kv_heads,
             head_dim: cfg.dim / cfg.n_heads,
             cache: cache.clone(),
+            max_seq_len: cfg.seq_len,
         })
     }
 }
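
The slicing change above is behavior-preserving: in candle, IndexOp::i with a range slices along the first dimension exactly like narrow(0, start, len). A standalone sketch of the equivalence using the published candle-core crate (tensor contents invented for illustration):

    use candle_core::{Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::arange(0f32, 10f32, &Device::Cpu)?;
        // New style from the diff: range indexing through the IndexOp trait.
        let a = t.i(2..6)?;
        // Old style: narrow(dim, start, len) on dimension 0.
        let b = t.narrow(0, 2, 4)?;
        assert_eq!(a.to_vec1::<f32>()?, b.to_vec1::<f32>()?);
        Ok(())
    }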