BERT Wasm (#902)

* implement wasm module

* add example to workspace

* add UI explore semantic similiarity

* change status messages

* formatting

* minor changes
This commit is contained in:
Radamés Ajna
2023-09-19 13:31:37 -07:00
committed by GitHub
parent 8696f64bae
commit 7ad82b87e4
9 changed files with 719 additions and 4 deletions

View File

@ -0,0 +1,92 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use candle_wasm_example_bert::console_log;
use tokenizers::{PaddingParams, Tokenizer};
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct Model {
bert: BertModel,
tokenizer: Tokenizer,
}
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, config: Vec<u8>) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let device = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
let config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let bert = BertModel::load(vb, &config)?;
Ok(Self { bert, tokenizer })
}
pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let input: Params =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
let sentences = input.sentences;
let normalize_embeddings = input.normalize_embeddings;
let device = &Device::Cpu;
if let Some(pp) = self.tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
self.tokenizer.with_padding(Some(pp));
}
let tokens = self
.tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(|m| JsError::new(&m.to_string()))?;
let token_ids: Vec<Tensor> = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<Result<Vec<_>, _>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let token_type_ids = token_ids.zeros_like()?;
console_log!("running inference on batch {:?}", token_ids.shape());
let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
console_log!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if normalize_embeddings {
embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
} else {
embeddings
};
let embeddings_data = embeddings.to_vec2()?;
Ok(serde_wasm_bindgen::to_value(&Embeddings {
data: embeddings_data,
})?)
}
}
#[derive(serde::Serialize, serde::Deserialize)]
struct Embeddings {
data: Vec<Vec<f64>>,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Params {
sentences: Vec<String>,
normalize_embeddings: bool,
}
fn main() {
console_error_panic_hook::set_once();
}

View File

@ -0,0 +1,20 @@
use candle_transformers::models::bert;
use wasm_bindgen::prelude::*;
pub use bert::{BertModel, Config, DTYPE};
pub use tokenizers::{PaddingParams, Tokenizer};
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}