From 1a183c988ac53fed01ff59390177c2043722a70d Mon Sep 17 00:00:00 2001
From: Jon Eskin
Date: Wed, 28 May 2025 00:17:07 -0400
Subject: [PATCH] Add fine-tuned text classifier to xlm roberta example (#2969)

---
 .../examples/xlm-roberta/Readme.md           | 23 +++++++++++
 candle-examples/examples/xlm-roberta/main.rs | 39 +++++++++++++++++++
 2 files changed, 62 insertions(+)

diff --git a/candle-examples/examples/xlm-roberta/Readme.md b/candle-examples/examples/xlm-roberta/Readme.md
index 496b14e3..e5445c40 100644
--- a/candle-examples/examples/xlm-roberta/Readme.md
+++ b/candle-examples/examples/xlm-roberta/Readme.md
@@ -28,3 +28,26 @@ Ranking Results:
 > Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
 --------------------------------------------------------------------------------
 ```
+
+Text-Classification:
+```bash
+cargo run --example xlm-roberta -- --task text-classification --model xlmr-formality-classifier
+```
+```markdown
+Formality Scores:
+Text 1: "I like you. I love you"
+ formal: 0.9933
+ informal: 0.0067
+
+Text 2: "Hey, what's up?"
+ formal: 0.8812
+ informal: 0.1188
+
+Text 3: "Siema, co porabiasz?"
+ formal: 0.9358
+ informal: 0.0642
+
+Text 4: "I feel deep regret and sadness about the situation in international politics."
+ formal: 0.9987
+ informal: 0.0013
+```
\ No newline at end of file
diff --git a/candle-examples/examples/xlm-roberta/main.rs b/candle-examples/examples/xlm-roberta/main.rs
index 47ab44b0..c1f75916 100644
--- a/candle-examples/examples/xlm-roberta/main.rs
+++ b/candle-examples/examples/xlm-roberta/main.rs
@@ -2,6 +2,7 @@ use std::path::PathBuf;
 
 use anyhow::{Error as E, Result};
 use candle::{Device, Tensor};
+use candle_nn::ops::softmax;
 use candle_nn::VarBuilder;
 use candle_transformers::models::xlm_roberta::{
     Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
@@ -17,12 +18,14 @@ enum Model {
     BgeRerankerBaseV2,
     XLMRobertaBase,
     XLMRobertaLarge,
+    XLMRFormalityClassifier,
 }
 
 #[derive(Debug, Clone, ValueEnum)]
 enum Task {
     FillMask,
     Reranker,
+    TextClassification,
 }
 
 #[derive(Parser, Debug)]
@@ -83,6 +86,12 @@ fn main() -> Result<()> {
             Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
             _ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
         },
+        Task::TextClassification => match args.model {
+            Model::XLMRFormalityClassifier => "s-nlp/xlmr_formality_classifier".to_string(),
+            _ => anyhow::bail!(
+                "XLM-RoBERTa models are not supported for text classification task"
+            ),
+        },
         },
     };
     let repo = api.repo(Repo::with_revision(
@@ -217,6 +226,36 @@ fn main() -> Result<()> {
             });
             println!("{:-<80}", "");
         }
+        Task::TextClassification => {
+            let sentences = vec![
+                "I like you. I love you".to_string(),
+                "Hey, what's up?".to_string(),
+                "Siema, co porabiasz?".to_string(),
+                "I feel deep regret and sadness about the situation in international politics."
+                    .to_string(),
+            ];
+            let model = XLMRobertaForSequenceClassification::new(2, &config, vb)?;
+            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
+
+            let attention_mask =
+                get_attention_mask(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
+            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
+
+            let logits = model
+                .forward(&input_ids, &attention_mask, &token_type_ids)?
+                .to_dtype(candle::DType::F32)?;
+
+            let probabilities = softmax(&logits, 1)?;
+            let probs_vec = probabilities.to_vec2::<f32>()?;
+
+            println!("Formality Scores:");
+            for (i, (text, probs)) in sentences.iter().zip(probs_vec.iter()).enumerate() {
+                println!("Text {}: \"{}\"", i + 1, text);
+                println!(" formal: {:.4}", probs[0]);
+                println!(" informal: {:.4}", probs[1]);
+                println!();
+            }
+        }
     }
     Ok(())
 }
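
A note on the scoring step in the new `Task::TextClassification` arm: the classification head emits one logit per class, and `softmax` along dim 1 converts each row of logits into probabilities that sum to 1, which is where the formal/informal scores in the Readme come from. The sketch below (not part of the patch; the logit values are made up, and it assumes the `candle`/`candle_nn` crates as used in candle-examples, where `candle` aliases candle-core) shows just that step in isolation:

```rust
// Minimal sketch: turning per-class logits into class probabilities.
use candle::{Device, Result, Tensor};
use candle_nn::ops::softmax;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical logits for two sentences x two classes (formal, informal).
    let logits = Tensor::new(&[[2.5f32, -2.5], [0.4, 1.6]], &device)?;
    // Softmax over dim 1 normalizes across classes, one row per sentence.
    let probs = softmax(&logits, 1)?;
    for (i, row) in probs.to_vec2::<f32>()?.iter().enumerate() {
        println!("Text {}: formal={:.4} informal={:.4}", i + 1, row[0], row[1]);
    }
    Ok(())
}
```

The `2` passed to `XLMRobertaForSequenceClassification::new(2, &config, vb)` matches the two output classes of the formality classifier; a head with more classes would change that argument and the indexing of each probability row accordingly.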