diff --git a/candle-examples/examples/jina-bert/main.rs b/candle-examples/examples/jina-bert/main.rs index d959d4cb..04b0c2d5 100644 --- a/candle-examples/examples/jina-bert/main.rs +++ b/candle-examples/examples/jina-bert/main.rs @@ -4,7 +4,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle_transformers::models::jina_bert::{BertModel, Config}; +use candle_transformers::models::jina_bert::{BertModel, Config, PositionEmbeddingType}; use anyhow::Error as E; use candle::{DType, Module, Tensor}; @@ -39,16 +39,25 @@ struct Args { #[arg(long)] model: Option<String>, + + #[arg(long)] + model_file: Option<String>, } impl Args { fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> { use hf_hub::{api::sync::Api, Repo, RepoType}; - let model = match &self.model { + let default = "jinaai/jina-embeddings-v2-base-en".to_string(); + let model_name = match &self.model { + Some(model) => model, + None => &default, + }; + + let model = match &self.model_file { Some(model_file) => std::path::PathBuf::from(model_file), None => Api::new()? .repo(Repo::new( - "jinaai/jina-embeddings-v2-base-en".to_string(), + model_name.to_string(), RepoType::Model, )) .get("model.safetensors")?, @@ -57,14 +66,14 @@ impl Args { Some(file) => std::path::PathBuf::from(file), None => Api::new()? .repo(Repo::new( - "sentence-transformers/all-MiniLM-L6-v2".to_string(), + model_name.to_string(), RepoType::Model, )) .get("tokenizer.json")?, }; let device = candle_examples::device(self.cpu)?; - let config = Config::v2_base(); let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?; + let config = Config::new(tokenizer.get_vocab_size(true), 768, 12, 12, 3072, candle_nn::Activation::Gelu, 8192, 2, 0.02, 1e-12, 0, PositionEmbeddingType::Alibi); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? 
}; let model = BertModel::new(vb, &config)?; Ok((model, tokenizer)) @@ -101,14 +110,21 @@ fn main() -> anyhow::Result<()> { .to_vec(); let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; println!("Loaded and encoded {:?}", start.elapsed()); - for idx in 0..args.n { - let start = std::time::Instant::now(); - let ys = model.forward(&token_ids)?; - if idx == 0 { - println!("{ys}"); - } - println!("Took {:?}", start.elapsed()); + let start = std::time::Instant::now(); + let embeddings = model.forward(&token_ids)?; + let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; + let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; + println!("pooled_embeddings: {embeddings}"); + let embeddings = if args.normalize_embeddings { + normalize_l2(&embeddings)? + } else { + embeddings + }; + if args.normalize_embeddings { + println!("normalized_embeddings: {embeddings}"); } + println!("Took {:?}", start.elapsed()); + } else { let sentences = [ "The cat sits outside", diff --git a/candle-transformers/src/models/jina_bert.rs b/candle-transformers/src/models/jina_bert.rs index 7e3c3887..97bc1b25 100644 --- a/candle-transformers/src/models/jina_bert.rs +++ b/candle-transformers/src/models/jina_bert.rs @@ -47,6 +47,36 @@ impl Config { position_embedding_type: PositionEmbeddingType::Alibi, } } + + pub fn new( + vocab_size: usize, + hidden_size: usize, + num_hidden_layers: usize, + num_attention_heads: usize, + intermediate_size: usize, + hidden_act: candle_nn::Activation, + max_position_embeddings: usize, + type_vocab_size: usize, + initializer_range: f64, + layer_norm_eps: f64, + pad_token_id: usize, + position_embedding_type: PositionEmbeddingType, + ) -> Self { + Config { + vocab_size, + hidden_size, + num_hidden_layers, + num_attention_heads, + intermediate_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + initializer_range, + layer_norm_eps, + pad_token_id, + position_embedding_type, + } + } } #[derive(Clone, Debug)]