mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Fixes for clippy 1.87. (#2956)
This commit is contained in:
@ -20,8 +20,8 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::{Encoding, PaddingParams, Tokenizer};
|
||||
|
||||
enum TaskType {
|
||||
Ner(DebertaV2NERModel),
|
||||
TextClassification(DebertaV2SeqClassificationModel),
|
||||
Ner(Box<DebertaV2NERModel>),
|
||||
TextClassification(Box<DebertaV2SeqClassificationModel>),
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone, ValueEnum)]
|
||||
@ -169,21 +169,16 @@ impl Args {
|
||||
|
||||
match self.task {
|
||||
ArgsTask::Ner => Ok((
|
||||
TaskType::Ner(DebertaV2NERModel::load(
|
||||
vb,
|
||||
&config,
|
||||
Some(id2label.clone()),
|
||||
)?),
|
||||
TaskType::Ner(DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))?.into()),
|
||||
config,
|
||||
tokenizer,
|
||||
id2label,
|
||||
)),
|
||||
ArgsTask::TextClassification => Ok((
|
||||
TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
|
||||
vb,
|
||||
&config,
|
||||
Some(id2label.clone()),
|
||||
)?),
|
||||
TaskType::TextClassification(
|
||||
DebertaV2SeqClassificationModel::load(vb, &config, Some(id2label.clone()))?
|
||||
.into(),
|
||||
),
|
||||
config,
|
||||
tokenizer,
|
||||
id2label,
|
||||
|
@ -16,8 +16,8 @@ use std::path::PathBuf;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum ModelType {
|
||||
Masked(DistilBertForMaskedLM),
|
||||
UnMasked(DistilBertModel),
|
||||
Masked(Box<DistilBertForMaskedLM>),
|
||||
UnMasked(Box<DistilBertModel>),
|
||||
}
|
||||
|
||||
impl ModelType {
|
||||
@ -144,10 +144,12 @@ impl Args {
|
||||
|
||||
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
||||
match self.model {
|
||||
Which::DistilbertForMaskedLM => {
|
||||
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
|
||||
}
|
||||
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
|
||||
Which::DistilbertForMaskedLM => Ok(ModelType::Masked(
|
||||
DistilBertForMaskedLM::load(vb, config)?.into(),
|
||||
)),
|
||||
Which::DistilBert => Ok(ModelType::UnMasked(
|
||||
DistilBertModel::load(vb, config)?.into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user