Llama2.c wasm module. (#686)

This commit is contained in:
Laurent Mazare
2023-08-31 08:44:32 +02:00
committed by GitHub
parent 9bd486fb96
commit 8e84d8a59b
3 changed files with 90 additions and 7 deletions

View File

@ -0,0 +1,83 @@
use candle::{Device, Tensor};
use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct Model {
inner: M,
logits_processor: LogitsProcessor,
tokens: Vec<u32>,
}
impl Model {
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
let dev = Device::Cpu;
let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
let logits = self.inner.llama.forward(&input, tokens.len())?;
let logits = logits.squeeze(0)?;
let next_token = self.logits_processor.sample(&logits)?;
self.tokens.push(next_token);
let text = match self.inner.tokenizer.id_to_token(next_token) {
Some(text) => text.replace('▁', " ").replace("<0x0A>", "\n"),
None => "".to_string(),
};
Ok(text)
}
}
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> {
let model = M::load(ModelData {
tokenizer,
model: weights,
});
let logits_processor = LogitsProcessor::new(299792458, None);
match model {
Ok(inner) => Ok(Self {
inner,
logits_processor,
tokens: vec![],
}),
Err(e) => Err(JsError::new(&e.to_string())),
}
}
#[wasm_bindgen]
pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
// First reset the cache.
{
let mut cache = self.inner.cache.kvs.lock().unwrap();
for elem in cache.iter_mut() {
*elem = None
}
}
let temp = if temp <= 0. { None } else { Some(temp) };
self.logits_processor = LogitsProcessor::new(299792458, temp);
self.tokens.clear();
let tokens = self
.inner
.tokenizer
.encode(prompt.to_string(), true)
.map_err(|m| JsError::new(&m.to_string()))?
.get_ids()
.to_vec();
let text = self
.process(&tokens)
.map_err(|m| JsError::new(&m.to_string()))?;
Ok(text)
}
#[wasm_bindgen]
pub fn next_token(&mut self) -> Result<String, JsError> {
let last_token = *self.tokens.last().unwrap();
let text = self
.process(&[last_token])
.map_err(|m| JsError::new(&m.to_string()))?;
Ok(text)
}
}
fn main() {}

View File

@ -1,5 +1,5 @@
mod app;
mod model;
mod worker;
pub mod model;
pub mod worker;
pub use app::App;
pub use worker::Worker;

View File

@ -49,11 +49,11 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
Ok(tensor)
}
struct Model {
cache: Cache,
pub struct Model {
pub cache: Cache,
config: Config,
llama: Llama,
tokenizer: Tokenizer,
pub llama: Llama,
pub tokenizer: Tokenizer,
}
pub struct LogitsProcessor {
@ -275,7 +275,7 @@ impl TransformerWeights {
}
impl Model {
fn load(md: ModelData) -> Result<Self> {
pub fn load(md: ModelData) -> Result<Self> {
let dev = Device::Cpu;
let mut model = std::io::Cursor::new(md.model);
let config = Config::from_reader(&mut model)?;