mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Jina Bert Example fix and more configuration (#2191)
* fix: fix jina bert example logic * feat: enable jina embeddings de * feat: allow more flexibility on Jina Bert
This commit is contained in:
@ -4,7 +4,7 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle_transformers::models::jina_bert::{BertModel, Config};
|
use candle_transformers::models::jina_bert::{BertModel, Config, PositionEmbeddingType};
|
||||||
|
|
||||||
use anyhow::Error as E;
|
use anyhow::Error as E;
|
||||||
use candle::{DType, Module, Tensor};
|
use candle::{DType, Module, Tensor};
|
||||||
@ -39,16 +39,25 @@ struct Args {
|
|||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_file: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
|
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
let model = match &self.model {
|
let default = "jinaai/jina-embeddings-v2-base-en".to_string();
|
||||||
|
let model_name = match &self.model {
|
||||||
|
Some(model) => model,
|
||||||
|
None => &default,
|
||||||
|
};
|
||||||
|
|
||||||
|
let model = match &self.model_file {
|
||||||
Some(model_file) => std::path::PathBuf::from(model_file),
|
Some(model_file) => std::path::PathBuf::from(model_file),
|
||||||
None => Api::new()?
|
None => Api::new()?
|
||||||
.repo(Repo::new(
|
.repo(Repo::new(
|
||||||
"jinaai/jina-embeddings-v2-base-en".to_string(),
|
model_name.to_string(),
|
||||||
RepoType::Model,
|
RepoType::Model,
|
||||||
))
|
))
|
||||||
.get("model.safetensors")?,
|
.get("model.safetensors")?,
|
||||||
@ -57,14 +66,14 @@ impl Args {
|
|||||||
Some(file) => std::path::PathBuf::from(file),
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
None => Api::new()?
|
None => Api::new()?
|
||||||
.repo(Repo::new(
|
.repo(Repo::new(
|
||||||
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
model_name.to_string(),
|
||||||
RepoType::Model,
|
RepoType::Model,
|
||||||
))
|
))
|
||||||
.get("tokenizer.json")?,
|
.get("tokenizer.json")?,
|
||||||
};
|
};
|
||||||
let device = candle_examples::device(self.cpu)?;
|
let device = candle_examples::device(self.cpu)?;
|
||||||
let config = Config::v2_base();
|
|
||||||
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
|
let config = Config::new(tokenizer.get_vocab_size(true), 768, 12, 12, 3072, candle_nn::Activation::Gelu, 8192, 2, 0.02, 1e-12, 0, PositionEmbeddingType::Alibi);
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||||
let model = BertModel::new(vb, &config)?;
|
let model = BertModel::new(vb, &config)?;
|
||||||
Ok((model, tokenizer))
|
Ok((model, tokenizer))
|
||||||
@ -101,14 +110,21 @@ fn main() -> anyhow::Result<()> {
|
|||||||
.to_vec();
|
.to_vec();
|
||||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||||
println!("Loaded and encoded {:?}", start.elapsed());
|
println!("Loaded and encoded {:?}", start.elapsed());
|
||||||
for idx in 0..args.n {
|
let start = std::time::Instant::now();
|
||||||
let start = std::time::Instant::now();
|
let embeddings = model.forward(&token_ids)?;
|
||||||
let ys = model.forward(&token_ids)?;
|
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||||
if idx == 0 {
|
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||||
println!("{ys}");
|
println!("pooled_embeddigns: {embeddings}");
|
||||||
}
|
let embeddings = if args.normalize_embeddings {
|
||||||
println!("Took {:?}", start.elapsed());
|
normalize_l2(&embeddings)?
|
||||||
|
} else {
|
||||||
|
embeddings
|
||||||
|
};
|
||||||
|
if args.normalize_embeddings {
|
||||||
|
println!("normalized_embeddings: {embeddings}");
|
||||||
}
|
}
|
||||||
|
println!("Took {:?}", start.elapsed());
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
let sentences = [
|
let sentences = [
|
||||||
"The cat sits outside",
|
"The cat sits outside",
|
||||||
|
@ -47,6 +47,36 @@ impl Config {
|
|||||||
position_embedding_type: PositionEmbeddingType::Alibi,
|
position_embedding_type: PositionEmbeddingType::Alibi,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new(
|
||||||
|
vocab_size: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
num_hidden_layers: usize,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
intermediate_size: usize,
|
||||||
|
hidden_act: candle_nn::Activation,
|
||||||
|
max_position_embeddings: usize,
|
||||||
|
type_vocab_size: usize,
|
||||||
|
initializer_range: f64,
|
||||||
|
layer_norm_eps: f64,
|
||||||
|
pad_token_id: usize,
|
||||||
|
position_embedding_type: PositionEmbeddingType,
|
||||||
|
) -> Self {
|
||||||
|
Config {
|
||||||
|
vocab_size,
|
||||||
|
hidden_size,
|
||||||
|
num_hidden_layers,
|
||||||
|
num_attention_heads,
|
||||||
|
intermediate_size,
|
||||||
|
hidden_act,
|
||||||
|
max_position_embeddings,
|
||||||
|
type_vocab_size,
|
||||||
|
initializer_range,
|
||||||
|
layer_norm_eps,
|
||||||
|
pad_token_id,
|
||||||
|
position_embedding_type,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
Reference in New Issue
Block a user