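//! Sentence embeddings with the Jina BERT model: embed a single `--prompt`,
//! or run a small batch demo that ranks sentence pairs by cosine similarity.
//!
//! Hypothetical invocation, assuming the example is registered under the name
//! `jina-bert` in candle-examples:
//!   cargo run --example jina-bert --release -- --prompt "The cat sits outside"

// Optional accelerated backends, linked in when the matching Cargo feature is enabled.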
#[cfg(feature = "mkl")]
|
|
extern crate intel_mkl_src;
|
|
|
|
#[cfg(feature = "accelerate")]
|
|
extern crate accelerate_src;
|
|
|
|
use candle_transformers::models::jina_bert::{BertModel, Config, PositionEmbeddingType};
|
|
|
|
use anyhow::Error as E;
|
|
use candle::{DType, Module, Tensor};
|
|
use candle_nn::VarBuilder;
|
|
use clap::Parser;
|
|
|
|
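
// Command-line options for the example, parsed with clap's derive API.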
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    /// A local tokenizer.json file, downloaded from the hub when not set.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The model repo to use on the HuggingFace hub.
    #[arg(long)]
    model: Option<String>,

    /// A local safetensors weight file, downloaded from the hub when not set.
    #[arg(long)]
    model_file: Option<String>,
}
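
// Resolves the model weights and tokenizer, from local paths when provided and
// otherwise by downloading them from the HuggingFace hub.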
impl Args {
    fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
        use hf_hub::{api::sync::Api, Repo, RepoType};
        let model_name = match self.model.as_ref() {
            Some(model) => model.to_string(),
            None => "jinaai/jina-embeddings-v2-base-en".to_string(),
        };

        let model = match &self.model_file {
            Some(model_file) => std::path::PathBuf::from(model_file),
            None => Api::new()?
                .repo(Repo::new(model_name.to_string(), RepoType::Model))
                .get("model.safetensors")?,
        };
        let tokenizer = match &self.tokenizer {
            Some(file) => std::path::PathBuf::from(file),
            None => Api::new()?
                .repo(Repo::new(model_name.to_string(), RepoType::Model))
                .get("tokenizer.json")?,
        };
        let device = candle_examples::device(self.cpu)?;
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
        // Config matching jina-embeddings-v2-base-en: BERT-base sized (768 hidden units,
        // 12 layers, 12 attention heads) with ALiBi position embeddings and an
        // 8192-token context.
        let config = Config::new(
            tokenizer.get_vocab_size(true),
            768,
            12,
            12,
            3072,
            candle_nn::Activation::Gelu,
            8192,
            2,
            0.02,
            1e-12,
            0,
            PositionEmbeddingType::Alibi,
        );
        // Memory-map the safetensors weights; unsafe because the mapped file must
        // not be modified while in use.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
        let model = BertModel::new(vb, &config)?;
        Ok((model, tokenizer))
    }
}
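
// Entry point: embeds a single prompt when `--prompt` is given, otherwise runs
// the batch similarity demo over a fixed set of sentences.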
fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    // When --tracing is set, emit a Chrome-format trace; the guard must stay
    // alive so the trace file is flushed when the program exits.
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let start = std::time::Instant::now();

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;
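
    // Single-prompt mode: encode one prompt and print its (optionally
    // L2-normalized) mean-pooled embedding.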
    if let Some(prompt) = args.prompt {
        let tokenizer = tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        let tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        println!("Loaded and encoded {:?}", start.elapsed());
        let start = std::time::Instant::now();
        let embeddings = model.forward(&token_ids)?;
        // Mean-pool over the token dimension to get a single sentence embedding.
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        println!("pooled_embeddings: {embeddings}");
        let embeddings = if args.normalize_embeddings {
            normalize_l2(&embeddings)?
        } else {
            embeddings
        };
        if args.normalize_embeddings {
            println!("normalized_embeddings: {embeddings}");
        }
        println!("Took {:?}", start.elapsed());
    } else {
        let sentences = [
            "The cat sits outside",
            "A man is playing guitar",
            "I love pasta",
            "The new movie is awesome",
            "The cat plays in the garden",
            "A woman watches TV",
            "The new movie is so great",
            "Do you like pizza?",
        ];
        let n_sentences = sentences.len();
        // Pad every sentence in the batch to the length of the longest one.
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = tokenizers::PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
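        // Tokenize the whole batch and stack the ids into a single
        // (n_sentences, seq_len) tensor.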
        let tokens = tokenizer
            .encode_batch(sentences.to_vec(), true)
            .map_err(E::msg)?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<candle::Result<Vec<_>>>()?;

        let token_ids = Tensor::stack(&token_ids, 0)?;
        println!("running inference on batch {:?}", token_ids.shape());
        let embeddings = model.forward(&token_ids)?;
        println!("generated embeddings {:?}", embeddings.shape());
        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        let embeddings = if args.normalize_embeddings {
            normalize_l2(&embeddings)?
        } else {
            embeddings
        };
        println!("pooled embeddings {:?}", embeddings.shape());
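
        // Compute the cosine similarity for every sentence pair, keeping the
        // indices so the closest pairs can be reported.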
        let mut similarities = vec![];
        for i in 0..n_sentences {
            let e_i = embeddings.get(i)?;
            for j in (i + 1)..n_sentences {
                let e_j = embeddings.get(j)?;
                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                similarities.push((cosine_similarity, i, j))
            }
        }
        // Sort by decreasing similarity and print the five closest pairs.
        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
        for &(score, i, j) in similarities[..5].iter() {
            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
        }
    }
    Ok(())
}
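
// Scale each row to unit L2 norm: v / sqrt(sum(v^2)), with the sum taken over
// the hidden dimension (dim 1).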
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}