mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
BERT Wasm (#902)
* implement wasm module * add example to workspace * add UI explore semantic similarity * change status messages * formatting * minor changes
This commit is contained in:
92
candle-wasm-examples/bert/src/bin/m.rs
Normal file
92
candle-wasm-examples/bert/src/bin/m.rs
Normal file
@ -0,0 +1,92 @@
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::bert::{BertModel, Config};
|
||||
use candle_wasm_example_bert::console_log;
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// A BERT encoder paired with its tokenizer, exposed to JavaScript
/// through wasm-bindgen (constructed via `Model::load`).
#[wasm_bindgen]
pub struct Model {
    // The loaded BERT model used for inference.
    bert: BertModel,
    // Tokenizer that converts input sentences into token ids.
    tokenizer: Tokenizer,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl Model {
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, config: Vec<u8>) -> Result<Model, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let device = &Device::Cpu;
|
||||
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
|
||||
let config: Config = serde_json::from_slice(&config)?;
|
||||
let tokenizer =
|
||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
let bert = BertModel::load(vb, &config)?;
|
||||
|
||||
Ok(Self { bert, tokenizer })
|
||||
}
|
||||
|
||||
pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> {
|
||||
let input: Params =
|
||||
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
let sentences = input.sentences;
|
||||
let normalize_embeddings = input.normalize_embeddings;
|
||||
|
||||
let device = &Device::Cpu;
|
||||
if let Some(pp) = self.tokenizer.get_padding_mut() {
|
||||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||
} else {
|
||||
let pp = PaddingParams {
|
||||
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||
..Default::default()
|
||||
};
|
||||
self.tokenizer.with_padding(Some(pp));
|
||||
}
|
||||
let tokens = self
|
||||
.tokenizer
|
||||
.encode_batch(sentences.to_vec(), true)
|
||||
.map_err(|m| JsError::new(&m.to_string()))?;
|
||||
|
||||
let token_ids: Vec<Tensor> = tokens
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let tokens = tokens.get_ids().to_vec();
|
||||
Tensor::new(tokens.as_slice(), device)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let token_ids = Tensor::stack(&token_ids, 0)?;
|
||||
let token_type_ids = token_ids.zeros_like()?;
|
||||
console_log!("running inference on batch {:?}", token_ids.shape());
|
||||
let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
|
||||
console_log!("generated embeddings {:?}", embeddings.shape());
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||
let embeddings = if normalize_embeddings {
|
||||
embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
|
||||
} else {
|
||||
embeddings
|
||||
};
|
||||
let embeddings_data = embeddings.to_vec2()?;
|
||||
Ok(serde_wasm_bindgen::to_value(&Embeddings {
|
||||
data: embeddings_data,
|
||||
})?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Response payload serialized back to JavaScript by `Model::get_embeddings`.
#[derive(serde::Serialize, serde::Deserialize)]
struct Embeddings {
    // One embedding vector per input sentence.
    data: Vec<Vec<f64>>,
}
|
||||
|
||||
/// Input parameters for `Model::get_embeddings`, deserialized from the
/// JavaScript caller via `serde_wasm_bindgen`.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Params {
    // Sentences to embed as a single batch.
    sentences: Vec<String>,
    // When true, L2-normalize each sentence embedding.
    normalize_embeddings: bool,
}
|
||||
/// Binary entry point: only installs the panic hook so Rust panics surface
/// on the JS console; all real work goes through the wasm-bindgen `Model` API.
fn main() {
    console_error_panic_hook::set_once();
}
|
20
candle-wasm-examples/bert/src/lib.rs
Normal file
20
candle-wasm-examples/bert/src/lib.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use candle_transformers::models::bert;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
pub use bert::{BertModel, Config, DTYPE};
|
||||
pub use tokenizers::{PaddingParams, Tokenizer};
|
||||
|
||||
// Imports from the JavaScript host environment.
#[wasm_bindgen]
extern "C" {
    // Use `js_namespace` here to bind `console.log(..)` instead of just
    // `log(..)`
    #[wasm_bindgen(js_namespace = console)]
    pub fn log(s: &str);
}
|
||||
|
||||
/// `println!`-style logging to the browser console: formats the arguments
/// and forwards the resulting string to the imported `console.log` binding.
#[macro_export]
macro_rules! console_log {
    // Note that this is using the `log` function imported above during
    // `bare_bones`
    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}
|
Reference in New Issue
Block a user