Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)

Compare commits: 0.9.0-alph ... wasm-llama

3 Commits

| Author | SHA1 | Date |
|---|---|---|
| | b2e4beb4f3 | |
| | d48bddbe01 | |
| | 145706f8df | |
@@ -12,6 +12,7 @@ license.workspace = true
 candle = { path = "../../candle-core", version = "0.1.0", package = "candle-core" }
 candle-nn = { path = "../../candle-nn", version = "0.1.0" }
 num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
 
 # App crates.
 anyhow = { workspace = true }
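
The `unstable_wasm` feature added above is what lets the `tokenizers` crate build for the `wasm32-unknown-unknown` target used by this example. As a rough, self-contained sketch (the `encode_prompt` helper is hypothetical, not part of the diff), this is the kind of call sequence the worker changes further down rely on, assuming the raw bytes of a `tokenizer.json` file are already in memory:

```rust
use tokenizers::Tokenizer;

// Hypothetical helper: turn fetched tokenizer.json bytes plus a prompt into token ids.
// The error type matches what the tokenizers crate itself returns.
fn encode_prompt(
    bytes: &[u8],
    prompt: &str,
) -> Result<Vec<u32>, Box<dyn std::error::Error + Send + Sync>> {
    let tokenizer = Tokenizer::from_bytes(bytes)?;
    // `true` requests special tokens (e.g. BOS) if the tokenizer is configured with them.
    let encoding = tokenizer.encode(prompt, true)?;
    Ok(encoding.get_ids().to_vec())
}
```
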
@@ -4,7 +4,7 @@
     <meta charset="utf-8" />
     <title>Welcome to Candle!</title>
 
-    <link data-trunk rel="copy-file" href="tokenizer.bin" />
+    <link data-trunk rel="copy-file" href="tokenizer.json" />
     <link data-trunk rel="copy-file" href="model.bin" />
     <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
     <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
@@ -46,6 +46,7 @@ pub struct App {
     status: String,
     loaded: bool,
     temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+    prompt: std::rc::Rc<std::cell::RefCell<String>>,
     generated: String,
     n_tokens: usize,
     current_decode: Option<CurrentDecode>,
@@ -53,7 +54,7 @@ pub struct App {
 }
 
 async fn model_data_load() -> Result<ModelData, JsValue> {
-    let tokenizer = fetch_url("tokenizer.bin").await?;
+    let tokenizer = fetch_url("tokenizer.json").await?;
     let model = fetch_url("model.bin").await?;
     console_log!("{}", model.len());
     Ok(ModelData { tokenizer, model })
@@ -80,6 +81,7 @@ impl Component for App {
             status,
             n_tokens: 0,
             temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+            prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
            generated: String::new(),
            current_decode: None,
            worker,
@@ -120,9 +122,10 @@ impl Component for App {
                    self.n_tokens = 0;
                    self.generated.clear();
                    let temp = *self.temperature.borrow();
-                    console_log!("temp: {}", temp);
+                    let prompt = self.prompt.borrow().clone();
+                    console_log!("temp: {}, prompt: {}", temp, prompt);
                    ctx.link()
-                        .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp)))
+                        .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
                }
                true
            }
@@ -181,6 +184,12 @@ impl Component for App {
            }
            Msg::Refresh
        });
+        let prompt = self.prompt.clone();
+        let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
+            let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+            *prompt.borrow_mut() = input.value();
+            Msg::Refresh
+        });
        html! {
            <div style="margin: 2%;">
                <div><p>{"Running "}
@@ -195,6 +204,8 @@ impl Component for App {
                <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
                {format!(" \u{00a0} {}", self.temperature.borrow())}
                <br/ >
+                {"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
+                <br/ >
                {
                    if self.loaded{
                        html!(<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>)
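
In the app changes above, the prompt is shared between the `oninput_prompt` callback and the `Msg::Run` handler through an `Rc<RefCell<String>>`, mirroring how `temperature` was already handled. A minimal, framework-free sketch of that pattern (the closure names stand in for the yew callbacks):

```rust
use std::{cell::RefCell, rc::Rc};

fn main() {
    // Shared, mutable string behind reference counting, as in the App struct above.
    let prompt: Rc<RefCell<String>> = Rc::new(RefCell::new(String::new()));

    // The input callback writes into the shared cell...
    let prompt_for_input = prompt.clone();
    let on_input = move |value: String| {
        *prompt_for_input.borrow_mut() = value;
    };

    // ...and the run handler later reads a clone of it.
    let on_run = move || prompt.borrow().clone();

    on_input("a small red llama".to_string());
    assert_eq!(on_run(), "a small red llama");
}
```
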
@@ -1,28 +1,3 @@
-#![allow(dead_code)]
-
-pub const WITH_TIMER: bool = true;
-
-struct Timer {
-    label: &'static str,
-}
-
-impl Timer {
-    fn new(label: &'static str) -> Self {
-        if WITH_TIMER {
-            web_sys::console::time_with_label(label);
-        }
-        Self { label }
-    }
-}
-
-impl Drop for Timer {
-    fn drop(&mut self) {
-        if WITH_TIMER {
-            web_sys::console::time_end_with_label(self.label)
-        }
-    }
-}
-
 mod app;
 mod model;
 mod worker;
@@ -106,14 +106,15 @@ struct CausalSelfAttention {
     n_key_value_head: usize,
     head_dim: usize,
     cache: Cache,
-    max_seq_len: usize,
 }
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
+        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
+        let cos = cos.unsqueeze(1)?;
+        let sin = sin.unsqueeze(1)?;
         let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
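
The rotary-embedding change above swaps `narrow(0, index_pos, seq_len)` for range indexing via `IndexOp`; over dimension 0 the two select the same `seq_len` rows starting at `index_pos`. A small sketch that checks this equivalence in isolation, assuming the crate is pulled in as `candle` the way this example's Cargo.toml aliases `candle-core`:

```rust
use candle::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for the cached cos table: 6 positions, 4 values each.
    let cos = Tensor::arange(0f32, 24., &dev)?.reshape((6, 4))?;
    let (index_pos, seq_len) = (2, 3);

    // Old style: narrow along dim 0.
    let a = cos.narrow(0, index_pos, seq_len)?;
    // New style: range indexing, as in the hunk above.
    let b = cos.i(index_pos..index_pos + seq_len)?;

    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
    Ok(())
}
```
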
@@ -196,7 +197,6 @@ impl CausalSelfAttention {
             n_key_value_head: cfg.n_kv_heads,
             head_dim: cfg.dim / cfg.n_heads,
             cache: cache.clone(),
-            max_seq_len: cfg.seq_len,
        })
    }
}
@@ -4,6 +4,7 @@ use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
 use candle_nn::{ops::softmax, VarBuilder};
 use rand::{distributions::Distribution, SeedableRng};
 use serde::{Deserialize, Serialize};
+use tokenizers::Tokenizer;
 use wasm_bindgen::prelude::*;
 use yew_agent::{HandlerId, Public, WorkerLink};
 
@@ -48,23 +49,6 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
     Ok(tensor)
 }
 
-struct Tokenizer {
-    tokens: Vec<String>,
-}
-
-impl Tokenizer {
-    fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
-        let mut tokens = Vec::with_capacity(c.vocab_size);
-        for _token_index in 0..c.vocab_size {
-            let token_len = read_i32(r)?;
-            let mut token = vec![0u8; token_len as usize];
-            r.read_exact(&mut token)?;
-            tokens.push(String::from_utf8_lossy(&token).into_owned())
-        }
-        Ok(Self { tokens })
-    }
-}
-
 struct Model {
     cache: Cache,
     config: Config,
@@ -107,13 +91,25 @@ impl LogitsProcessor {
 }
 
 impl Model {
-    fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> {
+    fn run(
+        &self,
+        link: &WorkerLink<Worker>,
+        id: HandlerId,
+        temp: f64,
+        prompt: String,
+    ) -> Result<()> {
         let dev = Device::Cpu;
         let temp = if temp <= 0. { None } else { Some(temp) };
-        console_log!("{temp:?}");
+        console_log!("{temp:?} {prompt}");
         let mut logits_processor = LogitsProcessor::new(299792458, temp);
         let mut index_pos = 0;
-        let mut tokens = vec![1u32];
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt.to_string(), true)
+            .map_err(|m| candle::Error::Msg(m.to_string()))?
+            .get_ids()
+            .to_vec();
+        link.respond(id, Ok(WorkerOutput::Generated(prompt)));
 
         for index in 0..self.config.seq_len - 10 {
             let context_size = if self.cache.use_kv_cache && index > 0 {
@@ -129,8 +125,10 @@ impl Model {
 
             let next_token = logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            let token = self.tokenizer.tokens[next_token as usize].clone();
-            link.respond(id, Ok(WorkerOutput::Generated(token)));
+            if let Some(text) = self.tokenizer.id_to_token(next_token) {
+                let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                link.respond(id, Ok(WorkerOutput::Generated(text)));
+            }
         }
         Ok(())
     }
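
The detokenization above relies on the SentencePiece conventions baked into the llama2.c tokenizer: pieces mark a leading space with '▁' and encode a newline as the byte token `<0x0A>`. A tiny sketch of that string mapping on its own (the `piece_to_text` helper is illustrative, not from the diff):

```rust
// Map a SentencePiece piece to display text, as done for each generated token above.
fn piece_to_text(piece: &str) -> String {
    piece.replace('▁', " ").replace("<0x0A>", "\n")
}

fn main() {
    let pieces = ["▁Hello", ",", "▁world", "<0x0A>"];
    let text: String = pieces.iter().map(|p| piece_to_text(p)).collect();
    assert_eq!(text, " Hello, world\n");
}
```
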
@@ -282,8 +280,8 @@ impl Model {
         let vb = weights.var_builder(&config, &dev)?;
         let cache = Cache::new(true, &config, vb.pp("rot"))?;
         let llama = Llama::load(vb, &cache, &config)?;
-        let mut tokenizer = std::io::Cursor::new(md.tokenizer);
-        let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?;
+        let tokenizer =
+            Tokenizer::from_bytes(&md.tokenizer).map_err(|m| candle::Error::Msg(m.to_string()))?;
         Ok(Self {
             cache,
             config,
@@ -301,7 +299,7 @@ pub struct Worker {
 #[derive(Serialize, Deserialize)]
 pub enum WorkerInput {
     ModelData(ModelData),
-    Run(f64),
+    Run(f64, String),
 }
 
 #[derive(Serialize, Deserialize)]
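
Because `WorkerInput` derives `Serialize`/`Deserialize`, extending `Run` to carry the prompt only requires the extra tuple field shown above; the message still round-trips through serde like any other payload. A standalone sketch using `serde_json` purely for illustration (the actual worker transport may use a different serde format):

```rust
use serde::{Deserialize, Serialize};

// Minimal stand-in for the worker message type above.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum WorkerInput {
    Run(f64, String),
}

fn main() {
    let msg = WorkerInput::Run(0.8, "tell me a story".to_string());
    let json = serde_json::to_string(&msg).unwrap();
    let back: WorkerInput = serde_json::from_str(&json).unwrap();
    assert_eq!(msg, back);
}
```
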
@@ -334,7 +332,7 @@ impl yew_agent::Worker for Worker {
                }
                Err(err) => Err(format!("model creation error {err:?}")),
            },
-            WorkerInput::Run(temp) => match &mut self.model {
+            WorkerInput::Run(temp, prompt) => match &mut self.model {
                None => Err("model has not been set yet".to_string()),
                Some(model) => {
                    {
@@ -343,7 +341,9 @@ impl yew_agent::Worker for Worker {
                            *elem = None
                        }
                    }
-                    let result = model.run(&self.link, id, temp).map_err(|e| e.to_string());
+                    let result = model
+                        .run(&self.link, id, temp, prompt)
+                        .map_err(|e| e.to_string());
                    Ok(WorkerOutput::GenerationDone(result))
                }
            },