mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Clean-up the llama2.c wasm example.
This commit is contained in:
@ -1,28 +1,3 @@
|
|||||||
#![allow(dead_code)]
|
|
||||||
|
|
||||||
pub const WITH_TIMER: bool = true;
|
|
||||||
|
|
||||||
struct Timer {
|
|
||||||
label: &'static str,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Timer {
|
|
||||||
fn new(label: &'static str) -> Self {
|
|
||||||
if WITH_TIMER {
|
|
||||||
web_sys::console::time_with_label(label);
|
|
||||||
}
|
|
||||||
Self { label }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for Timer {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
if WITH_TIMER {
|
|
||||||
web_sys::console::time_end_with_label(self.label)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mod app;
|
mod app;
|
||||||
mod model;
|
mod model;
|
||||||
mod worker;
|
mod worker;
|
||||||
|
@ -106,14 +106,15 @@ struct CausalSelfAttention {
|
|||||||
n_key_value_head: usize,
|
n_key_value_head: usize,
|
||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
cache: Cache,
|
cache: Cache,
|
||||||
max_seq_len: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
impl CausalSelfAttention {
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
|
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
|
||||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
|
||||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
|
||||||
|
let cos = cos.unsqueeze(1)?;
|
||||||
|
let sin = sin.unsqueeze(1)?;
|
||||||
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
||||||
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
|
||||||
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
|
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
|
||||||
@ -196,7 +197,6 @@ impl CausalSelfAttention {
|
|||||||
n_key_value_head: cfg.n_kv_heads,
|
n_key_value_head: cfg.n_kv_heads,
|
||||||
head_dim: cfg.dim / cfg.n_heads,
|
head_dim: cfg.dim / cfg.n_heads,
|
||||||
cache: cache.clone(),
|
cache: cache.clone(),
|
||||||
max_seq_len: cfg.seq_len,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user