Phi 2 wasm (#1432)

* add phi 2.0 quantized model wasm

* cols

* spell

* bug
This commit is contained in:
Radamés Ajna
2023-12-14 04:04:17 -08:00
committed by GitHub
parent 5e33c85c8f
commit 104e196d46
3 changed files with 102 additions and 26 deletions

View File

@ -5,6 +5,7 @@ use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausa
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle_wasm_example_phi::console_log;
use js_sys::Date;
use serde::Deserialize;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@ -23,6 +24,12 @@ pub struct Model {
repeat_last_n: usize,
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ModelName {
pub _name_or_path: String,
}
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
@ -34,15 +41,25 @@ impl Model {
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let name: ModelName = serde_json::from_slice(&config)?;
let config: Config = serde_json::from_slice(&config)?;
console_log!("config loaded {:?}", name);
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let start = Date::now();
console_log!("weights len: {:?}", weights.len());
let model = if quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
let model = QMixFormer::new(&config, vb)?;
SelectedModel::Quantized(model)
console_log!("weights loaded");
if name._name_or_path == "microsoft/phi-2" {
let model = QMixFormer::new_v2(&config, vb)?;
SelectedModel::Quantized(model)
} else {
let model = QMixFormer::new(&config, vb)?;
SelectedModel::Quantized(model)
}
} else {
let device = &Device::Cpu;
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;