mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Phi 1.5 wasm module (#966)
* add phi wasm module * replace input with textarea * trim input prompt * stop on <|endoftext|> * formatting * clean up * add blurb, and syntax highlighting * add phi-v1.5 wasm * add note * hide Options on details * add first token to generated text * whitespaces for new line * fix: abort -> aborted
This commit is contained in:
139
candle-wasm-examples/phi/src/bin/m.rs
Normal file
139
candle-wasm-examples/phi/src/bin/m.rs
Normal file
@ -0,0 +1,139 @@
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||
use candle_wasm_example_phi::console_log;
|
||||
use js_sys::Date;
|
||||
use tokenizers::Tokenizer;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// The concrete network backing the wasm-exposed `Model`: either the
/// full-precision safetensors build or the quantized (GGUF) build of the
/// Phi MixFormer architecture. Selected by the `quantized` flag in
/// `Model::load`.
enum SelectedModel {
    // F32 weights loaded from a safetensors buffer, run on Device::Cpu.
    MixFormer(MixFormer),
    // Quantized weights loaded from a GGUF buffer.
    Quantized(QMixFormer),
}
|
||||
|
||||
/// Phi 1.5 text-generation model exposed to JavaScript via wasm-bindgen.
///
/// Typical JS-side flow: construct with `new Model(weights, tokenizer,
/// quantized)`, call `init_with_prompt(...)` once per prompt, then call
/// `next_token()` repeatedly to stream generated text.
#[wasm_bindgen]
pub struct Model {
    // Underlying network: full-precision or quantized variant.
    model: SelectedModel,
    // Tokenizer used both to encode prompts and to decode sampled ids.
    tokenizer: Tokenizer,
    // Sampling state (seed / temperature / top-p); replaced on each
    // init_with_prompt call.
    logits_processor: LogitsProcessor,
    // Token ids sampled so far for the current prompt.
    tokens: Vec<u32>,
    // Repetition penalty applied to recent tokens; 1.0 disables it.
    repeat_penalty: f32,
    // Number of trailing tokens the repeat penalty looks back over.
    repeat_last_n: usize,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl Model {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, quantized: bool) -> Result<Model, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let config: Config = Config::v1_5();
|
||||
let tokenizer =
|
||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
let start = Date::now();
|
||||
let model = if quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let model = QMixFormer::new(&config, vb)?;
|
||||
SelectedModel::Quantized(model)
|
||||
} else {
|
||||
let device = &Device::Cpu;
|
||||
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
|
||||
let model = MixFormer::new(&config, vb)?;
|
||||
SelectedModel::MixFormer(model)
|
||||
};
|
||||
console_log!("model loaded in {:?}s", (Date::now() - start) / 1000.);
|
||||
let logits_processor = LogitsProcessor::new(299792458, None, None);
|
||||
Ok(Self {
|
||||
model,
|
||||
tokenizer,
|
||||
tokens: vec![],
|
||||
logits_processor,
|
||||
repeat_penalty: 1.,
|
||||
repeat_last_n: 64,
|
||||
})
|
||||
}
|
||||
#[wasm_bindgen]
|
||||
pub fn init_with_prompt(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
temp: f64,
|
||||
top_p: f64,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
seed: u64,
|
||||
) -> Result<String, JsError> {
|
||||
match &mut self.model {
|
||||
SelectedModel::MixFormer(m) => m.clear_kv_cache(),
|
||||
SelectedModel::Quantized(m) => m.clear_kv_cache(),
|
||||
};
|
||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||
let top_p = if top_p <= 0. || top_p >= 1. {
|
||||
None
|
||||
} else {
|
||||
Some(top_p)
|
||||
};
|
||||
self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
self.repeat_penalty = repeat_penalty;
|
||||
self.repeat_last_n = repeat_last_n;
|
||||
self.tokens.clear();
|
||||
let tokens = self
|
||||
.tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(|m| JsError::new(&m.to_string()))?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let text = self
|
||||
.process(&tokens)
|
||||
.map_err(|m| JsError::new(&m.to_string()))?;
|
||||
Ok(text)
|
||||
}
|
||||
#[wasm_bindgen]
|
||||
pub fn next_token(&mut self) -> Result<String, JsError> {
|
||||
let last_token = *self.tokens.last().unwrap();
|
||||
let text = self
|
||||
.process(&[last_token])
|
||||
.map_err(|m| JsError::new(&m.to_string()))?;
|
||||
Ok(text)
|
||||
}
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
|
||||
let dev = Device::Cpu;
|
||||
let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
SelectedModel::MixFormer(m) => m.forward(&input)?,
|
||||
SelectedModel::Quantized(m) => m.forward(&input)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
self.tokens.push(next_token);
|
||||
let token = match self.tokenizer.decode(&[next_token], false) {
|
||||
Ok(token) => token,
|
||||
Err(e) => {
|
||||
console_log!("error decoding token: {:?}", e);
|
||||
"".to_string()
|
||||
}
|
||||
};
|
||||
// console_log!("token: {:?}: {:?}", token, next_token);
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
/// Binary entry point. The real API of this wasm module is the
/// wasm-bindgen `Model` type above; `main` only installs the panic hook
/// so Rust panics surface in the browser console instead of vanishing.
fn main() {
    console_error_panic_hook::set_once();
}
|
Reference in New Issue
Block a user