Mirror of https://github.com/huggingface/candle.git (synced 2025-06-14 09:57:10 +00:00)
Add fine-tuned text classifier to xlm roberta example (#2969)
candle-examples/examples/xlm-roberta/README.md

@@ -28,3 +28,26 @@ Ranking Results:
 > Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
 --------------------------------------------------------------------------------
 ```
+
+Text-Classification:
+```bash
+cargo run --example xlm-roberta -- --task text-classification --model xlmr-formality-classifier
+```
+```markdown
+Formality Scores:
+Text 1: "I like you. I love you"
+ formal: 0.9933
+ informal: 0.0067
+
+Text 2: "Hey, what's up?"
+ formal: 0.8812
+ informal: 0.1188
+
+Text 3: "Siema, co porabiasz?"
+ formal: 0.9358
+ informal: 0.0642
+
+Text 4: "I feel deep regret and sadness about the situation in international politics."
+ formal: 0.9987
+ informal: 0.0013
+```
candle-examples/examples/xlm-roberta/main.rs

@@ -2,6 +2,7 @@ use std::path::PathBuf;
 
 use anyhow::{Error as E, Result};
 use candle::{Device, Tensor};
+use candle_nn::ops::softmax;
 use candle_nn::VarBuilder;
 use candle_transformers::models::xlm_roberta::{
     Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
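The one line this hunk adds is the `candle_nn::ops::softmax` import. For context, a minimal standalone sketch (not part of the commit, with made-up logit values) of what that function does: it normalizes a row of logits into probabilities that sum to 1.

```rust
use candle::{Device, Tensor};
use candle_nn::ops::softmax;

fn main() -> candle::Result<()> {
    // A single batch row of two-class logits; the values are illustrative.
    let logits = Tensor::new(&[[2.0f32, 0.5]], &Device::Cpu)?;
    // Softmax over dim 1 turns the pair into probabilities summing to 1.
    let probs = softmax(&logits, 1)?;
    println!("{:?}", probs.to_vec2::<f32>()?); // ~[[0.8176, 0.1824]]
    Ok(())
}
```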
@@ -17,12 +18,14 @@ enum Model {
     BgeRerankerBaseV2,
     XLMRobertaBase,
     XLMRobertaLarge,
+    XLMRFormalityClassifier,
 }
 
 #[derive(Debug, Clone, ValueEnum)]
 enum Task {
     FillMask,
     Reranker,
+    TextClassification,
 }
 
 #[derive(Parser, Debug)]
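This hunk registers the new variants. Since `Task` and `Model` derive clap's `ValueEnum`, kebab-case CLI values map onto the CamelCase variants, which is how the README's `--task text-classification --model xlmr-formality-classifier` flags reach the new code paths. A minimal sketch of that mapping (assumed clap 4 derive behavior, not taken from the diff):

```rust
use clap::{Parser, ValueEnum};

#[derive(Debug, Clone, ValueEnum)]
enum Task {
    FillMask,
    Reranker,
    TextClassification,
}

#[derive(Parser, Debug)]
struct Args {
    #[arg(long, value_enum)]
    task: Task,
}

fn main() {
    // The kebab-case value on the command line selects the CamelCase variant.
    let args = Args::parse_from(["xlm-roberta", "--task", "text-classification"]);
    assert!(matches!(args.task, Task::TextClassification));
}
```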
@@ -83,6 +86,12 @@ fn main() -> Result<()> {
                 Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
                 _ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
             },
+            Task::TextClassification => match args.model {
+                Model::XLMRFormalityClassifier => "s-nlp/xlmr_formality_classifier".to_string(),
+                _ => anyhow::bail!(
+                    "XLM-RoBERTa models are not supported for text classification task"
+                ),
+            },
         },
     };
     let repo = api.repo(Repo::with_revision(
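This hunk maps the new `Model` variant to the Hub repo id `s-nlp/xlmr_formality_classifier`, which then flows into the `api.repo(Repo::with_revision(...))` call in the trailing context line. A sketch of that download step in isolation, using the hf-hub sync API; the "main" revision and the `config.json` file name are assumptions for illustration, not read from the diff:

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        "s-nlp/xlmr_formality_classifier".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    // Downloads the file on first use, then returns the cached local path.
    let config_path = repo.get("config.json")?;
    println!("{}", config_path.display());
    Ok(())
}
```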
@@ -217,6 +226,36 @@ fn main() -> Result<()> {
             });
             println!("{:-<80}", "");
         }
+        Task::TextClassification => {
+            let sentences = vec![
+                "I like you. I love you".to_string(),
+                "Hey, what's up?".to_string(),
+                "Siema, co porabiasz?".to_string(),
+                "I feel deep regret and sadness about the situation in international politics."
+                    .to_string(),
+            ];
+            let model = XLMRobertaForSequenceClassification::new(2, &config, vb)?;
+            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
+
+            let attention_mask =
+                get_attention_mask(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
+            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
+
+            let logits = model
+                .forward(&input_ids, &attention_mask, &token_type_ids)?
+                .to_dtype(candle::DType::F32)?;
+
+            let probabilities = softmax(&logits, 1)?;
+            let probs_vec = probabilities.to_vec2::<f32>()?;
+
+            println!("Formality Scores:");
+            for (i, (text, probs)) in sentences.iter().zip(probs_vec.iter()).enumerate() {
+                println!("Text {}: \"{}\"", i + 1, text);
+                println!(" formal: {:.4}", probs[0]);
+                println!(" informal: {:.4}", probs[1]);
+                println!();
+            }
+        }
     }
     Ok(())
 }
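The new arm builds a two-label `XLMRobertaForSequenceClassification` head, runs the batch through it, and softmaxes the logits so each text gets a formal/informal probability pair. One hypothetical extension of the printing loop (not in the commit): reduce each pair to a single predicted label with an argmax, keeping the example's label order of index 0 = formal, index 1 = informal.

```rust
/// Hypothetical helper: pick the most likely label and its probability
/// from one row of `probs_vec` (index 0 = formal, index 1 = informal).
fn top_label(probs: &[f32]) -> (&'static str, f32) {
    if probs[0] >= probs[1] {
        ("formal", probs[0])
    } else {
        ("informal", probs[1])
    }
}

fn main() {
    // Probabilities copied from the README output for Text 2.
    let probs = [0.8812f32, 0.1188];
    let (label, p) = top_label(&probs);
    println!("{label}: {p:.4}"); // formal: 0.8812
}
```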