mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Fixes all clippy warnings
This commit is contained in:
@ -20,14 +20,14 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::{Encoding, PaddingParams, Tokenizer};
|
||||
|
||||
enum TaskType {
|
||||
NER(DebertaV2NERModel),
|
||||
Ner(DebertaV2NERModel),
|
||||
TextClassification(DebertaV2SeqClassificationModel),
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone, ValueEnum)]
|
||||
enum ArgsTask {
|
||||
/// Named Entity Recognition
|
||||
NER,
|
||||
Ner,
|
||||
|
||||
/// Text Classification
|
||||
TextClassification,
|
||||
@ -36,7 +36,7 @@ enum ArgsTask {
|
||||
impl Display for ArgsTask {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
ArgsTask::NER => write!(f, "ner"),
|
||||
ArgsTask::Ner => write!(f, "ner"),
|
||||
ArgsTask::TextClassification => write!(f, "text-classification"),
|
||||
}
|
||||
}
|
||||
@ -77,7 +77,7 @@ struct Args {
|
||||
benchmark_iters: Option<usize>,
|
||||
|
||||
/// Which task to run
|
||||
#[arg(long, default_value_t = ArgsTask::NER)]
|
||||
#[arg(long, default_value_t = ArgsTask::Ner)]
|
||||
task: ArgsTask,
|
||||
|
||||
/// Use model from a specific directory instead of HuggingFace local cache.
|
||||
@ -142,7 +142,7 @@ impl Args {
|
||||
// Command-line id2label takes precedence. Otherwise, use model config's id2label.
|
||||
// If neither is specified, then we can't proceed.
|
||||
let id2label = if let Some(id2labelstr) = &self.id2label {
|
||||
serde_json::from_str(&&id2labelstr.as_str())?
|
||||
serde_json::from_str(id2labelstr.as_str())?
|
||||
} else if let Some(id2label) = &config.id2label {
|
||||
id2label.clone()
|
||||
} else {
|
||||
@ -174,8 +174,8 @@ impl Args {
|
||||
let vb = vb.set_prefix("deberta");
|
||||
|
||||
match self.task {
|
||||
ArgsTask::NER => Ok((
|
||||
TaskType::NER(DebertaV2NERModel::load(
|
||||
ArgsTask::Ner => Ok((
|
||||
TaskType::Ner(DebertaV2NERModel::load(
|
||||
vb,
|
||||
&config,
|
||||
Some(id2label.clone()),
|
||||
@ -200,7 +200,7 @@ impl Args {
|
||||
|
||||
fn get_device(model_type: &TaskType) -> &Device {
|
||||
match model_type {
|
||||
TaskType::NER(ner_model) => &ner_model.device,
|
||||
TaskType::Ner(ner_model) => &ner_model.device,
|
||||
TaskType::TextClassification(classification_model) => &classification_model.device,
|
||||
}
|
||||
}
|
||||
@ -253,9 +253,9 @@ fn main() -> Result<()> {
|
||||
let mut token_type_id_stack: Vec<Tensor> = Vec::default();
|
||||
|
||||
for encoding in &tokenizer_encodings {
|
||||
encoding_stack.push(Tensor::new(encoding.get_ids(), &device)?);
|
||||
attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), &device)?);
|
||||
token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), &device)?);
|
||||
encoding_stack.push(Tensor::new(encoding.get_ids(), device)?);
|
||||
attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), device)?);
|
||||
token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), device)?);
|
||||
}
|
||||
|
||||
ModelInput {
|
||||
@ -272,7 +272,7 @@ fn main() -> Result<()> {
|
||||
);
|
||||
|
||||
match task_type {
|
||||
TaskType::NER(ner_model) => {
|
||||
TaskType::Ner(ner_model) => {
|
||||
if let Some(num_iters) = args.benchmark_iters {
|
||||
create_benchmark(num_iters, model_input)(
|
||||
|input_ids, token_type_ids, attention_mask| {
|
||||
@ -326,7 +326,7 @@ fn main() -> Result<()> {
|
||||
current_row_result.push(NERItem {
|
||||
entity: label,
|
||||
word: current_row_tokens[input_id_idx].clone(),
|
||||
score: current_row_max_scores[input_id_idx].clone(),
|
||||
score: current_row_max_scores[input_id_idx],
|
||||
start: current_row_encoding.get_offsets()[input_id_idx].0,
|
||||
end: current_row_encoding.get_offsets()[input_id_idx].1,
|
||||
index: input_id_idx,
|
||||
|
Reference in New Issue
Block a user