#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use std::fmt::Display; use std::path::PathBuf; use anyhow::bail; use anyhow::{Error as E, Result}; use candle::{Device, Tensor}; use candle_nn::ops::softmax; use candle_nn::VarBuilder; use candle_transformers::models::debertav2::{Config as DebertaV2Config, DebertaV2NERModel}; use candle_transformers::models::debertav2::{DebertaV2SeqClassificationModel, Id2Label}; use candle_transformers::models::debertav2::{NERItem, TextClassificationItem}; use clap::{ArgGroup, Parser, ValueEnum}; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::{Encoding, PaddingParams, Tokenizer}; enum TaskType { Ner(DebertaV2NERModel), TextClassification(DebertaV2SeqClassificationModel), } #[derive(Parser, Debug, Clone, ValueEnum)] enum ArgsTask { /// Named Entity Recognition Ner, /// Text Classification TextClassification, } impl Display for ArgsTask { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { ArgsTask::Ner => write!(f, "ner"), ArgsTask::TextClassification => write!(f, "text-classification"), } } } #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] #[command(group(ArgGroup::new("model") .required(true) .args(&["model_id", "model_path"])))] struct Args { /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, /// The model id to use from HuggingFace #[arg(long, requires_if("model_id", "revision"))] model_id: Option, /// Revision of the model to use (default: "main") #[arg(long, default_value = "main")] revision: String, /// Specify a sentence to inference. Specify multiple times to inference multiple sentences. #[arg(long = "sentence", name="sentences", num_args = 1..)] sentences: Vec, /// Use the pytorch weights rather than the by-default safetensors #[arg(long)] use_pth: bool, /// Perform a very basic benchmark on inferencing, using N number of iterations #[arg(long)] benchmark_iters: Option, /// Which task to run #[arg(long, default_value_t = ArgsTask::Ner)] task: ArgsTask, /// Use model from a specific directory instead of HuggingFace local cache. /// Using this ignores model_id and revision args. #[arg(long)] model_path: Option, /// Pass in an Id2Label if the model config does not provide it, in JSON format. Example: --id2label='{"0": "True", "1": "False"}' #[arg(long)] id2label: Option, } impl Args { fn build_model_and_tokenizer( &self, ) -> Result<(TaskType, DebertaV2Config, Tokenizer, Id2Label)> { let device = candle_examples::device(self.cpu)?; // Get files from either the HuggingFace API, or from a specified local directory. let (config_filename, tokenizer_filename, weights_filename) = { match &self.model_path { Some(base_path) => { if !base_path.is_dir() { bail!("Model path {} is not a directory.", base_path.display()) } let config = base_path.join("config.json"); let tokenizer = base_path.join("tokenizer.json"); let weights = if self.use_pth { base_path.join("pytorch_model.bin") } else { base_path.join("model.safetensors") }; (config, tokenizer, weights) } None => { let repo = Repo::with_revision( self.model_id.as_ref().unwrap().clone(), RepoType::Model, self.revision.clone(), ); let api = Api::new()?; let api = api.repo(repo); let config = api.get("config.json")?; let tokenizer = api.get("tokenizer.json")?; let weights = if self.use_pth { api.get("pytorch_model.bin")? } else { api.get("model.safetensors")? 
                    };
                    (config, tokenizer, weights)
                }
            }
        };

        let config = std::fs::read_to_string(config_filename)?;
        let config: DebertaV2Config = serde_json::from_str(&config)?;

        // Command-line id2label takes precedence. Otherwise, use model config's id2label.
        // If neither is specified, then we can't proceed.
        let id2label = if let Some(id2labelstr) = &self.id2label {
            serde_json::from_str(id2labelstr.as_str())?
        } else if let Some(id2label) = &config.id2label {
            id2label.clone()
        } else {
            bail!("Id2Label not found in the model configuration nor specified as a parameter")
        };

        let mut tokenizer = Tokenizer::from_file(tokenizer_filename)
            .map_err(|e| candle::Error::Msg(format!("Tokenizer error: {e}")))?;
        tokenizer.with_padding(Some(PaddingParams::default()));

        let vb = if self.use_pth {
            VarBuilder::from_pth(
                &weights_filename,
                candle_transformers::models::debertav2::DTYPE,
                &device,
            )?
        } else {
            unsafe {
                VarBuilder::from_mmaped_safetensors(
                    &[weights_filename],
                    candle_transformers::models::debertav2::DTYPE,
                    &device,
                )?
            }
        };

        let vb = vb.set_prefix("deberta");

        match self.task {
            ArgsTask::Ner => Ok((
                TaskType::Ner(DebertaV2NERModel::load(
                    vb,
                    &config,
                    Some(id2label.clone()),
                )?),
                config,
                tokenizer,
                id2label,
            )),
            ArgsTask::TextClassification => Ok((
                TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
                    vb,
                    &config,
                    Some(id2label.clone()),
                )?),
                config,
                tokenizer,
                id2label,
            )),
        }
    }
}

// Return the device the loaded model lives on so inputs can be created on the same device.
fn get_device(model_type: &TaskType) -> &Device {
    match model_type {
        TaskType::Ner(ner_model) => &ner_model.device,
        TaskType::TextClassification(classification_model) => &classification_model.device,
    }
}

struct ModelInput {
    encoding: Vec<Encoding>,
    input_ids: Tensor,
    attention_mask: Tensor,
    token_type_ids: Tensor,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let model_load_time = std::time::Instant::now();
    let (task_type, _model_config, tokenizer, id2label) = args.build_model_and_tokenizer()?;
    println!(
        "Loaded model and tokenizers in {:?}",
        model_load_time.elapsed()
    );

    let device = get_device(&task_type);

    let tokenize_time = std::time::Instant::now();
    let model_input: ModelInput = {
        let tokenizer_encodings = tokenizer
            .encode_batch(args.sentences, true)
            .map_err(E::msg)?;

        let mut encoding_stack: Vec<Tensor> = Vec::default();
        let mut attention_mask_stack: Vec<Tensor> = Vec::default();
        let mut token_type_id_stack: Vec<Tensor> = Vec::default();

        for encoding in &tokenizer_encodings {
            encoding_stack.push(Tensor::new(encoding.get_ids(), device)?);
            attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), device)?);
            token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), device)?);
        }

        ModelInput {
            encoding: tokenizer_encodings,
            input_ids: Tensor::stack(&encoding_stack[..], 0)?,
            attention_mask: Tensor::stack(&attention_mask_stack[..], 0)?,
            token_type_ids: Tensor::stack(&token_type_id_stack[..], 0)?,
        }
    };

    println!(
        "Tokenized and loaded inputs in {:?}",
        tokenize_time.elapsed()
    );

    match task_type {
        TaskType::Ner(ner_model) => {
            if let Some(num_iters) = args.benchmark_iters {
                create_benchmark(num_iters, model_input)(
                    |input_ids, token_type_ids, attention_mask| {
                        ner_model.forward(input_ids, Some(token_type_ids), Some(attention_mask))?;
                        Ok(())
                    },
                )?;

                std::process::exit(0);
            }

            let inference_time = std::time::Instant::now();
            let logits = ner_model.forward(
                &model_input.input_ids,
                Some(model_input.token_type_ids),
                Some(model_input.attention_mask),
            )?;

            println!("Inferenced inputs in {:?}", inference_time.elapsed());

            // Softmax over the label dimension gives per-token scores; argmax picks the
            // predicted label id for each token.
            let max_scores_vec = softmax(&logits, 2)?.max(2)?.to_vec2::<f32>()?;
            let max_indices_vec: Vec<Vec<u32>> = logits.argmax(2)?.to_vec2()?;
            let input_ids = model_input.input_ids.to_vec2::<u32>()?;
            let mut results: Vec<Vec<NERItem>> = Default::default();

            for (input_row_idx, input_id_row) in input_ids.iter().enumerate() {
                let mut current_row_result: Vec<NERItem> = Default::default();
                let current_row_encoding = model_input.encoding.get(input_row_idx).unwrap();
                let current_row_tokens = current_row_encoding.get_tokens();
                let current_row_max_scores = max_scores_vec.get(input_row_idx).unwrap();

                for (input_id_idx, _input_id) in input_id_row.iter().enumerate() {
                    // Do not include special characters in output
                    if current_row_encoding.get_special_tokens_mask()[input_id_idx] == 1 {
                        continue;
                    }

                    let max_label_idx = max_indices_vec
                        .get(input_row_idx)
                        .unwrap()
                        .get(input_id_idx)
                        .unwrap();

                    let label = id2label.get(max_label_idx).unwrap().clone();

                    // Do not include those labeled as "O" ("Other")
                    if label == "O" {
                        continue;
                    }

                    current_row_result.push(NERItem {
                        entity: label,
                        word: current_row_tokens[input_id_idx].clone(),
                        score: current_row_max_scores[input_id_idx],
                        start: current_row_encoding.get_offsets()[input_id_idx].0,
                        end: current_row_encoding.get_offsets()[input_id_idx].1,
                        index: input_id_idx,
                    });
                }

                results.push(current_row_result);
            }

            println!("\n{:?}", results);
        }

        TaskType::TextClassification(classification_model) => {
            let inference_time = std::time::Instant::now();
            let logits = classification_model.forward(
                &model_input.input_ids,
                Some(model_input.token_type_ids),
                Some(model_input.attention_mask),
            )?;

            println!("Inferenced inputs in {:?}", inference_time.elapsed());

            let predictions = logits.argmax(1)?.to_vec1::<u32>()?;
            let scores = softmax(&logits, 1)?.max(1)?.to_vec1::<f32>()?;
            let mut results = Vec::<TextClassificationItem>::default();

            for (idx, prediction) in predictions.iter().enumerate() {
                results.push(TextClassificationItem {
                    label: id2label[prediction].clone(),
                    score: scores[idx],
                });
            }

            println!("\n{:?}", results);
        }
    }

    Ok(())
}

// Build a closure that times `num_iters` forward passes over the already-tokenized inputs
// and reports min/avg/max latency.
fn create_benchmark<F>(
    num_iters: usize,
    model_input: ModelInput,
) -> impl Fn(F) -> Result<(), candle::Error>
where
    F: Fn(&Tensor, Tensor, Tensor) -> Result<(), candle::Error>,
{
    move |code: F| -> Result<(), candle::Error> {
        println!("Running {num_iters} iterations...");

        let mut durations = Vec::with_capacity(num_iters);
        for _ in 0..num_iters {
            let token_type_ids = model_input.token_type_ids.clone();
            let attention_mask = model_input.attention_mask.clone();
            let start = std::time::Instant::now();
            code(&model_input.input_ids, token_type_ids, attention_mask)?;
            let duration = start.elapsed();
            durations.push(duration.as_nanos());
        }

        let min_time = *durations.iter().min().unwrap();
        let max_time = *durations.iter().max().unwrap();
        let avg_time = durations.iter().sum::<u128>() as f64 / num_iters as f64;

        println!("Min time: {:.3} ms", min_time as f64 / 1_000_000.0);
        println!("Avg time: {:.3} ms", avg_time / 1_000_000.0);
        println!("Max time: {:.3} ms", max_time as f64 / 1_000_000.0);

        Ok(())
    }
}
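
// Example invocations, included here only as a sketch. The example name `debertav2`, the
// placeholder model id/path, and the sample sentences are assumptions, not part of the
// program above; the flags themselves are derived from the `Args` struct via clap.
//
//   # NER against a Hub token-classification checkpoint:
//   cargo run --example debertav2 --release -- \
//     --model-id=<hub-model-id> \
//     --sentence="John lives in Seattle." \
//     --sentence="Jane works at Acme Corp."
//
//   # Text classification from a local directory, supplying id2label explicitly:
//   cargo run --example debertav2 --release -- \
//     --model-path=/path/to/model/dir --task=text-classification \
//     --id2label='{"0": "NEGATIVE", "1": "POSITIVE"}' \
//     --sentence="I really enjoyed this."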