Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00

Fixes all clippy warnings
@@ -20,14 +20,14 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{Encoding, PaddingParams, Tokenizer};

 enum TaskType {
-    NER(DebertaV2NERModel),
+    Ner(DebertaV2NERModel),
     TextClassification(DebertaV2SeqClassificationModel),
 }

 #[derive(Parser, Debug, Clone, ValueEnum)]
 enum ArgsTask {
     /// Named Entity Recognition
-    NER,
+    Ner,

     /// Text Classification
     TextClassification,
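
The `NER` → `Ner` renames in this hunk (and throughout the hunks below) clear clippy's `upper_case_acronyms` lint, which wants acronyms in type and variant names written in UpperCamelCase. A minimal sketch of the warning, using stand-in variants rather than the candle types:

    // Running `cargo clippy` on this enum flags the all-caps variant and
    // suggests lowercasing the acronym after its initial letter: `Ner`.
    #[allow(dead_code)]
    enum Task {
        NER, // warning: name `NER` contains a capitalized acronym
        TextClassification,
    }

    fn main() {}
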
@@ -36,7 +36,7 @@ enum ArgsTask {
 impl Display for ArgsTask {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         match self {
-            ArgsTask::NER => write!(f, "ner"),
+            ArgsTask::Ner => write!(f, "ner"),
             ArgsTask::TextClassification => write!(f, "text-classification"),
         }
     }
@@ -77,7 +77,7 @@ struct Args {
     benchmark_iters: Option<usize>,

     /// Which task to run
-    #[arg(long, default_value_t = ArgsTask::NER)]
+    #[arg(long, default_value_t = ArgsTask::Ner)]
     task: ArgsTask,

     /// Use model from a specific directory instead of HuggingFace local cache.
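
One reason the rename touches so many places: clap's `default_value_t` renders the default through the type's `Display` impl (and parses it back through `ValueEnum`), so the `Display` arm and the default attribute have to stay in sync with the variant name. A sketch of that interaction, assuming clap 4 with the `derive` feature:

    use clap::{Parser, ValueEnum};
    use std::fmt::Display;

    #[derive(Debug, Clone, ValueEnum)]
    enum ArgsTask {
        Ner,
        TextClassification,
    }

    // `--help` shows the default as "ner" because default_value_t uses Display.
    impl Display for ArgsTask {
        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            match self {
                ArgsTask::Ner => write!(f, "ner"),
                ArgsTask::TextClassification => write!(f, "text-classification"),
            }
        }
    }

    #[derive(Parser, Debug)]
    struct Args {
        #[arg(long, default_value_t = ArgsTask::Ner)]
        task: ArgsTask,
    }

    fn main() {
        println!("task = {}", Args::parse().task);
    }
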
@@ -142,7 +142,7 @@ impl Args {
         // Command-line id2label takes precedence. Otherwise, use model config's id2label.
         // If neither is specified, then we can't proceed.
         let id2label = if let Some(id2labelstr) = &self.id2label {
-            serde_json::from_str(&&id2labelstr.as_str())?
+            serde_json::from_str(id2labelstr.as_str())?
         } else if let Some(id2label) = &config.id2label {
             id2label.clone()
         } else {
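
The `&&id2labelstr.as_str()` fix above is clippy's `needless_borrow` lint: `as_str()` already returns a `&str`, and the extra `&&` only adds reference layers the compiler auto-derefs away. A std-only sketch of the same pattern:

    fn takes_str(s: &str) -> usize {
        s.len()
    }

    fn main() {
        let owned = String::from("{\"0\": \"O\"}");
        // Before (flagged as needless_borrow): takes_str(&&owned.as_str());
        // `as_str()` already yields &str, so pass it directly.
        println!("{}", takes_str(owned.as_str()));
    }
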
@@ -174,8 +174,8 @@ impl Args {
         let vb = vb.set_prefix("deberta");

         match self.task {
-            ArgsTask::NER => Ok((
-                TaskType::NER(DebertaV2NERModel::load(
+            ArgsTask::Ner => Ok((
+                TaskType::Ner(DebertaV2NERModel::load(
                     vb,
                     &config,
                     Some(id2label.clone()),
@@ -200,7 +200,7 @@ impl Args {

 fn get_device(model_type: &TaskType) -> &Device {
     match model_type {
-        TaskType::NER(ner_model) => &ner_model.device,
+        TaskType::Ner(ner_model) => &ner_model.device,
         TaskType::TextClassification(classification_model) => &classification_model.device,
     }
 }
@@ -253,9 +253,9 @@ fn main() -> Result<()> {
     let mut token_type_id_stack: Vec<Tensor> = Vec::default();

     for encoding in &tokenizer_encodings {
-        encoding_stack.push(Tensor::new(encoding.get_ids(), &device)?);
-        attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), &device)?);
-        token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), &device)?);
+        encoding_stack.push(Tensor::new(encoding.get_ids(), device)?);
+        attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), device)?);
+        token_type_id_stack.push(Tensor::new(encoding.get_type_ids(), device)?);
     }

     ModelInput {
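
Same `needless_borrow` lint in another shape: in this scope `device` is already a `&Device`, so `&device` produces a `&&Device` that `Tensor::new` would have to deref again. A sketch with stand-in types (the real `Device` and `Tensor` come from candle):

    struct Device;

    // Stand-in for an API like Tensor::new that borrows a device.
    fn new_tensor(ids: &[u32], device: &Device) -> usize {
        let _ = device;
        ids.len()
    }

    fn main() {
        let device = &Device; // already a reference, as in the hunk above
        // Before (flagged as needless_borrow): new_tensor(&[1, 2, 3], &device);
        println!("{}", new_tensor(&[1, 2, 3], device));
    }
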
@@ -272,7 +272,7 @@ fn main() -> Result<()> {
     );

     match task_type {
-        TaskType::NER(ner_model) => {
+        TaskType::Ner(ner_model) => {
             if let Some(num_iters) = args.benchmark_iters {
                 create_benchmark(num_iters, model_input)(
                     |input_ids, token_type_ids, attention_mask| {
@@ -326,7 +326,7 @@ fn main() -> Result<()> {
                 current_row_result.push(NERItem {
                     entity: label,
                     word: current_row_tokens[input_id_idx].clone(),
-                    score: current_row_max_scores[input_id_idx].clone(),
+                    score: current_row_max_scores[input_id_idx],
                     start: current_row_encoding.get_offsets()[input_id_idx].0,
                     end: current_row_encoding.get_offsets()[input_id_idx].1,
                     index: input_id_idx,
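
The dropped `.clone()` on the score is clippy's `clone_on_copy` lint: the per-token score is a `Copy` scalar (an `f32` here, judging by the surrounding code), so indexing already yields a copy and the `.clone()` is pure noise. A minimal sketch:

    fn main() {
        let scores: Vec<f32> = vec![0.98, 0.87, 0.91];
        let idx = 1;
        // Before (flagged as clone_on_copy): let score = scores[idx].clone();
        // f32 is Copy, so indexing copies the value already.
        let score = scores[idx];
        println!("{score}");
    }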