Mirror of https://github.com/huggingface/candle.git
Use a proper tokenizer.

Cargo.toml
@@ -12,6 +12,7 @@ license.workspace = true
 candle = { path = "../../candle-core", version = "0.1.0", package = "candle-core" }
 candle-nn = { path = "../../candle-nn", version = "0.1.0" }
 num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
 
 # App crates.
 anyhow = { workspace = true }
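
Note: the `unstable_wasm` feature flag is what makes the `tokenizers` crate usable on the `wasm32-unknown-unknown` target this example is built for; the crate's default build is not wasm-compatible.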

index.html
@@ -4,7 +4,7 @@
     <meta charset="utf-8" />
     <title>Welcome to Candle!</title>
 
-    <link data-trunk rel="copy-file" href="tokenizer.bin" />
+    <link data-trunk rel="copy-file" href="tokenizer.json" />
     <link data-trunk rel="copy-file" href="model.bin" />
     <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
     <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />

src/app.rs
@@ -53,7 +53,7 @@ pub struct App {
 }
 
 async fn model_data_load() -> Result<ModelData, JsValue> {
-    let tokenizer = fetch_url("tokenizer.bin").await?;
+    let tokenizer = fetch_url("tokenizer.json").await?;
     let model = fetch_url("model.bin").await?;
     console_log!("{}", model.len());
     Ok(ModelData { tokenizer, model })
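
Note: `fetch_url` is the example's existing helper that downloads an asset and returns its bytes. For a self-contained picture, here is a minimal sketch of such a helper, assuming `web-sys`, `js-sys` and `wasm-bindgen-futures` with the `Window` and `Response` features enabled; `fetch_bytes` is a hypothetical name, not the example's actual code.

    use wasm_bindgen::{JsCast, JsValue};
    use wasm_bindgen_futures::JsFuture;

    // Download `url` and return the response body as raw bytes.
    async fn fetch_bytes(url: &str) -> Result<Vec<u8>, JsValue> {
        let window = web_sys::window().ok_or_else(|| JsValue::from_str("no window"))?;
        let resp = JsFuture::from(window.fetch_with_str(url)).await?;
        let resp: web_sys::Response = resp.dyn_into()?;
        // Read the body as an ArrayBuffer, then copy it into a Vec<u8>.
        let buf = JsFuture::from(resp.array_buffer()?).await?;
        Ok(js_sys::Uint8Array::new(&buf).to_vec())
    }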

src/worker.rs
@@ -4,6 +4,7 @@ use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
 use candle_nn::{ops::softmax, VarBuilder};
 use rand::{distributions::Distribution, SeedableRng};
 use serde::{Deserialize, Serialize};
+use tokenizers::Tokenizer;
 use wasm_bindgen::prelude::*;
 use yew_agent::{HandlerId, Public, WorkerLink};
 
@@ -48,23 +49,6 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
     Ok(tensor)
 }
 
-struct Tokenizer {
-    tokens: Vec<String>,
-}
-
-impl Tokenizer {
-    fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
-        let mut tokens = Vec::with_capacity(c.vocab_size);
-        for _token_index in 0..c.vocab_size {
-            let token_len = read_i32(r)?;
-            let mut token = vec![0u8; token_len as usize];
-            r.read_exact(&mut token)?;
-            tokens.push(String::from_utf8_lossy(&token).into_owned())
-        }
-        Ok(Self { tokens })
-    }
-}
-
 struct Model {
     cache: Cache,
     config: Config,
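
Note: the removed reader parsed the llama2.c `tokenizer.bin` layout, `vocab_size` entries each stored as an i32 length followed by that many raw bytes (see `read_i32` above). Such a table only supports id-to-string lookup; the `tokenizers` crate additionally handles prompt encoding and special tokens from the same `tokenizer.json`.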
@@ -129,8 +113,10 @@ impl Model {
 
             let next_token = logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            let token = self.tokenizer.tokens[next_token as usize].clone();
-            link.respond(id, Ok(WorkerOutput::Generated(token)));
+            if let Some(text) = self.tokenizer.id_to_token(next_token) {
+                let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                link.respond(id, Ok(WorkerOutput::Generated(text)));
+            }
         }
         Ok(())
     }
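
Note: `id_to_token` returns the raw vocabulary entry, in which SentencePiece-style vocabularies mark word boundaries with `▁` (U+2581) and encode raw bytes as `<0xNN>` tokens, hence the two `replace` calls to recover spaces and newlines. Decoding token by token suits the worker, which streams each generated piece to the UI as it is sampled. When a whole sequence is available, the crate can decode it in one call; a sketch (the exact `decode` signature, `&[u32]` vs `Vec<u32>`, varies across `tokenizers` versions):

    use tokenizers::Tokenizer;

    // Decode a full id sequence at once instead of token by token.
    fn decode_all(tokenizer: &Tokenizer, ids: &[u32]) -> Result<String, String> {
        // `true` skips special tokens such as <s> and </s>.
        tokenizer.decode(ids, true).map_err(|e| e.to_string())
    }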
@@ -282,8 +268,8 @@ impl Model {
         let vb = weights.var_builder(&config, &dev)?;
         let cache = Cache::new(true, &config, vb.pp("rot"))?;
         let llama = Llama::load(vb, &cache, &config)?;
-        let mut tokenizer = std::io::Cursor::new(md.tokenizer);
-        let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?;
+        let tokenizer =
+            Tokenizer::from_bytes(&md.tokenizer).map_err(|m| candle::Error::Msg(m.to_string()))?;
         Ok(Self {
             cache,
             config,
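
Note: `Tokenizer::from_bytes` fits the wasm setup, where `tokenizer.json` arrives as fetched bytes, and the `map_err` bridges the crate's error type into `candle::Error`. On native targets the same load is usually done straight from disk; a minimal sketch:

    use tokenizers::Tokenizer;

    // Native-target equivalent: read tokenizer.json directly from a file path.
    fn load_tokenizer(path: &str) -> Result<Tokenizer, candle::Error> {
        Tokenizer::from_file(path).map_err(|m| candle::Error::Msg(m.to_string()))
    }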