use std::path::PathBuf;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{
    Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};

#[derive(Debug, Clone, ValueEnum)]
enum Model {
    BgeRerankerBase,
    BgeRerankerLarge,
    BgeRerankerBaseV2,
    XLMRobertaBase,
    XLMRobertaLarge,
}

#[derive(Debug, Clone, ValueEnum)]
enum Task {
    FillMask,
    Reranker,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model to use; see the available models at
    /// https://huggingface.co/models?library=sentence-transformers&sort=trending
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long, default_value = "bge-reranker-base")]
    model: Model,

    #[arg(long, default_value = "reranker")]
    task: Task,

    // Path to the tokenizer file.
    #[arg(long)]
    tokenizer_file: Option<String>,

    // Path to the weight files.
    #[arg(long)]
    weight_files: Option<String>,

    // Path to the config file.
    #[arg(long)]
    config_file: Option<String>,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,
}

fn main() -> Result<()> {
    let args = Args::parse();
    let api = Api::new()?;
    // Resolve the hub repo: an explicit --model-id wins, otherwise pick a
    // default that matches the requested task.
    let model_id = match &args.model_id {
        Some(model_id) => model_id.to_string(),
        None => match args.task {
            Task::FillMask => match args.model {
                Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
                Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
                _ => anyhow::bail!("BGE models are not supported for the fill-mask task"),
            },
            Task::Reranker => match args.model {
                Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
                Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
                Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
                _ => anyhow::bail!("XLM-RoBERTa models are not supported for the reranker task"),
            },
        },
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));

    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };

    let config_filename = match args.config_file {
        Some(file) => PathBuf::from(file),
        None => repo.get("config.json")?,
    };

    // Prefer safetensors weights and fall back to the PyTorch checkpoint.
    let weights_filename = match args.weight_files {
        Some(files) => PathBuf::from(files),
        None => match repo.get("model.safetensors") {
            Ok(safetensors) => safetensors,
            Err(_) => match repo.get("pytorch_model.bin") {
                Ok(pytorch_model) => pytorch_model,
                Err(e) => {
                    return Err(anyhow::Error::msg(format!(
                        "Model weights not found. The weights should either be a `model.safetensors` or a `pytorch_model.bin` file. Error: {e}"
                    )));
                }
            },
        },
    };
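    // Hypothetical offline invocation sketch: the --tokenizer-file,
    // --config-file and --weight-files flags above skip the hub download
    // entirely. The example name `xlm-roberta` and the local paths are
    // illustrative assumptions, not files shipped with the repo:
    //
    //   cargo run --example xlm-roberta --release -- \
    //       --tokenizer-file ./tokenizer.json \
    //       --config-file ./config.json \
    //       --weight-files ./model.safetensors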
    let config = std::fs::read_to_string(config_filename)?;
    let config: Config = serde_json::from_str(&config)?;
    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let device = candle_examples::device(args.cpu)?;

    let vb = if weights_filename.ends_with("model.safetensors") {
        unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
                .unwrap()
        }
    } else {
        println!("Loading weights from pytorch_model.bin");
        VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
    };

    // Pad to the longest sequence in the batch; truncation stays disabled.
    tokenizer
        .with_padding(Some(PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            pad_id: config.pad_token_id,
            ..Default::default()
        }))
        .with_truncation(None)
        .map_err(E::msg)?;

    match args.task {
        Task::FillMask => {
            // Fill-mask prompts must contain the `<mask>` token so the model
            // has something to predict.
            let prompt = vec![
                "Hello I'm a <mask> model.".to_string(),
                "I'm a <mask> boy.".to_string(),
                "I'm <mask> in berlin.".to_string(),
            ];
            let model = XLMRobertaForMaskedLM::new(&config, vb)?;

            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
            let attention_mask =
                get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;

            let output = model
                .forward(
                    &input_ids,
                    &attention_mask,
                    &token_type_ids,
                    None,
                    None,
                    None,
                )?
                .to_dtype(candle::DType::F32)?;

            // Greedy decoding: take the argmax token at every position, then
            // decode each sentence back to text.
            let max_outs = output.argmax(2)?;
            let max_out = max_outs.to_vec2::<u32>()?;
            let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
            let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
            for (i, sentence) in decoded.iter().enumerate() {
                println!("Sentence: {} : {}", i + 1, sentence);
            }
        }
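        // Reranking approach, as implemented below: each (query, document)
        // pair is scored with a single logit (`num_labels` is 1), sigmoid maps
        // the logit into (0, 1), and a descending argsort over the scores
        // turns them into a ranking, so the panda-related documents should
        // surface first.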
        Task::Reranker => {
            let query = "what is panda?".to_string();

            let documents = [
                "South Korea is a country in East Asia.".to_string(),
                "There are forests in the mountains.".to_string(),
                "Pandas look like bears.".to_string(),
                "There are some animals with black and white fur.".to_string(),
                "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string(),
            ];

            // Create (query, document) pairs, one per candidate document.
            let pairs = documents
                .iter()
                .map(|doc| (query.clone(), doc.clone()))
                .collect::<Vec<_>>();

            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
            let attention_mask =
                get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;

            let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
            let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
            let output = candle_nn::ops::sigmoid(&output)?.t()?;

            // Document indices, sorted by score in descending order.
            let ranks = output
                .arg_sort_last_dim(false)?
                .to_vec2::<u32>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>();

            println!("\nRanking Results:");
            println!("{:-<80}", "");
            documents.iter().enumerate().for_each(|(idx, doc)| {
                let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
                let score = output
                    .get_on_dim(1, idx)
                    .unwrap()
                    .to_dtype(candle::DType::F32)
                    .unwrap()
                    .to_vec1::<f32>()
                    .unwrap();
                println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
            });
            println!("{:-<80}", "");
        }
    }
    Ok(())
}

#[derive(Debug)]
pub enum TokenizeInput<'a> {
    Single(&'a [String]),
    Pairs(&'a [(String, String)]),
}

/// Tokenize a batch of single sentences or sentence pairs and stack the token
/// ids into a single `(batch, seq_len)` tensor.
pub fn tokenize_batch(
    tokenizer: &Tokenizer,
    input: TokenizeInput,
    device: &Device,
) -> anyhow::Result<Tensor> {
    let tokens = match input {
        TokenizeInput::Single(text_batch) => tokenizer
            .encode_batch(text_batch.to_vec(), true)
            .map_err(E::msg)?,
        TokenizeInput::Pairs(pairs) => tokenizer
            .encode_batch(pairs.to_vec(), true)
            .map_err(E::msg)?,
    };

    let token_ids = tokens
        .iter()
        .map(|tokens| {
            let tokens = tokens.get_ids().to_vec();
            Tensor::new(tokens.as_slice(), device)
        })
        .collect::<candle::Result<Vec<_>>>()?;

    Ok(Tensor::stack(&token_ids, 0)?)
}

/// Build the matching `(batch, seq_len)` attention-mask tensor (1 for real
/// tokens, 0 for padding).
pub fn get_attention_mask(
    tokenizer: &Tokenizer,
    input: TokenizeInput,
    device: &Device,
) -> anyhow::Result<Tensor> {
    let tokens = match input {
        TokenizeInput::Single(text_batch) => tokenizer
            .encode_batch(text_batch.to_vec(), true)
            .map_err(E::msg)?,
        TokenizeInput::Pairs(pairs) => tokenizer
            .encode_batch(pairs.to_vec(), true)
            .map_err(E::msg)?,
    };

    let attention_mask = tokens
        .iter()
        .map(|tokens| {
            let tokens = tokens.get_attention_mask().to_vec();
            Tensor::new(tokens.as_slice(), device)
        })
        .collect::<candle::Result<Vec<_>>>()?;

    Ok(Tensor::stack(&attention_mask, 0)?)
}
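// Usage sketch: assuming this file is registered as the `xlm-roberta` example
// of a candle checkout (the example name is an assumption; the flags and their
// defaults come from `Args` above):
//
//   cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-base
//   cargo run --example xlm-roberta --release -- --task fill-mask --model xlm-roberta-base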