Add fine-tuned text classifier to xlm roberta example (#2969)

Jon Eskin
2025-05-28 00:17:07 -04:00
committed by GitHub
parent cac51fe16a
commit 1a183c988a
2 changed files with 62 additions and 0 deletions

@@ -28,3 +28,26 @@ Ranking Results:
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
--------------------------------------------------------------------------------
```
Text-Classification:
```bash
cargo run --example xlm-roberta -- --task text-classification --model xlmr-formality-classifier
```
```markdown
Formality Scores:
Text 1: "I like you. I love you"
 formal: 0.9933
 informal: 0.0067
Text 2: "Hey, what's up?"
 formal: 0.8812
 informal: 0.1188
Text 3: "Siema, co porabiasz?"
 formal: 0.9358
 informal: 0.0642
Text 4: "I feel deep regret and sadness about the situation in international politics."
 formal: 0.9987
 informal: 0.0013
```
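The two scores for each text are softmax probabilities over the classifier's two output logits, so each pair sums to 1. A minimal sketch of just that final step in candle, with made-up logit values for illustration:
```rust
use candle::{Device, Tensor};
use candle_nn::ops::softmax;

fn main() -> candle::Result<()> {
    // Hypothetical logits for a single text, ordered [formal, informal].
    let logits = Tensor::new(&[[2.5f32, -2.5]], &Device::Cpu)?;
    // Softmax over the class dimension turns the logits into
    // probabilities that sum to 1 per row, as in the scores above.
    let probs = softmax(&logits, 1)?;
    println!("{:?}", probs.to_vec2::<f32>()?);
    Ok(())
}
```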

@@ -2,6 +2,7 @@ use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::ops::softmax;
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{
    Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
@@ -17,12 +18,14 @@ enum Model {
    BgeRerankerBaseV2,
    XLMRobertaBase,
    XLMRobertaLarge,
    XLMRFormalityClassifier,
}
#[derive(Debug, Clone, ValueEnum)]
enum Task {
    FillMask,
    Reranker,
    TextClassification,
}
#[derive(Parser, Debug)]
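The new TextClassification variant becomes a CLI value without extra plumbing: clap's ValueEnum derive maps variant names to kebab-case strings, which is why the README invocation passes `--task text-classification`. A standalone sketch of that mechanism (this minimal Args struct is illustrative, not the example's actual one):
```rust
use clap::{Parser, ValueEnum};

#[derive(Debug, Clone, ValueEnum)]
enum Task {
    FillMask,
    Reranker,
    TextClassification,
}

// Hypothetical minimal Args; the real example carries more flags.
#[derive(Parser, Debug)]
struct Args {
    // ValueEnum derives kebab-case values from variant names, so
    // `--task text-classification` parses to Task::TextClassification.
    #[arg(long, value_enum)]
    task: Task,
}

fn main() {
    let args = Args::parse();
    println!("{:?}", args.task);
}
```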
@@ -83,6 +86,12 @@ fn main() -> Result<()> {
                Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
                _ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
            },
            Task::TextClassification => match args.model {
                Model::XLMRFormalityClassifier => "s-nlp/xlmr_formality_classifier".to_string(),
                _ => anyhow::bail!(
                    "XLM-RoBERTa models are not supported for text classification task"
                ),
            },
        },
    };
    let repo = api.repo(Repo::with_revision(
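With the repo id resolved, the example fetches the model files through the hf-hub crate, which the `api.repo(...)` call above begins. A sketch of that flow in isolation, assuming hf-hub's sync API and a "main" revision (the downloaded file names here are illustrative):
```rust
use anyhow::Result;
use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> Result<()> {
    let api = Api::new()?;
    // Repo id taken from the diff above; "main" is an assumed revision.
    let repo = api.repo(Repo::with_revision(
        "s-nlp/xlmr_formality_classifier".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    // Files are cached locally; `get` returns the cached path.
    let config_path = repo.get("config.json")?;
    let tokenizer_path = repo.get("tokenizer.json")?;
    println!("{config_path:?}\n{tokenizer_path:?}");
    Ok(())
}
```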
@@ -217,6 +226,36 @@ fn main() -> Result<()> {
            });
            println!("{:-<80}", "");
        }
        Task::TextClassification => {
            let sentences = vec![
                "I like you. I love you".to_string(),
                "Hey, what's up?".to_string(),
                "Siema, co porabiasz?".to_string(),
                "I feel deep regret and sadness about the situation in international politics."
                    .to_string(),
            ];
            let model = XLMRobertaForSequenceClassification::new(2, &config, vb)?;
            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
            let attention_mask =
                get_attention_mask(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
            let logits = model
                .forward(&input_ids, &attention_mask, &token_type_ids)?
                .to_dtype(candle::DType::F32)?;
            let probabilities = softmax(&logits, 1)?;
            let probs_vec = probabilities.to_vec2::<f32>()?;
            println!("Formality Scores:");
            for (i, (text, probs)) in sentences.iter().zip(probs_vec.iter()).enumerate() {
                println!("Text {}: \"{}\"", i + 1, text);
                println!(" formal: {:.4}", probs[0]);
                println!(" informal: {:.4}", probs[1]);
                println!();
            }
        }
    }
    Ok(())
}
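Two details in the new arm are easy to miss: XLM-RoBERTa has no segment embeddings, so the forward pass takes an all-zero token_type_ids tensor shaped like input_ids, and each row of probs_vec holds the [formal, informal] pair for one sentence. If a single label were wanted rather than both scores, one possible extension (not part of this commit) is an argmax per row:
```rust
// Hypothetical helper, not in the commit: pick the higher-probability
// label for one row of `probs_vec` (index 0 = formal, 1 = informal).
fn label_for(probs: &[f32]) -> &'static str {
    if probs[0] >= probs[1] {
        "formal"
    } else {
        "informal"
    }
}
```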