mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Fixes for clippy 1.87. (#2956)
This commit is contained in:
@ -16,10 +16,9 @@ fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
|
|||||||
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
|
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
|
||||||
let magic_number = read_u32(reader)?;
|
let magic_number = read_u32(reader)?;
|
||||||
if magic_number != expected {
|
if magic_number != expected {
|
||||||
Err(io::Error::new(
|
Err(io::Error::other(format!(
|
||||||
io::ErrorKind::Other,
|
"incorrect magic number {magic_number} != {expected}"
|
||||||
format!("incorrect magic number {magic_number} != {expected}"),
|
)))?;
|
||||||
))?;
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -20,8 +20,8 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
|||||||
use tokenizers::{Encoding, PaddingParams, Tokenizer};
|
use tokenizers::{Encoding, PaddingParams, Tokenizer};
|
||||||
|
|
||||||
enum TaskType {
|
enum TaskType {
|
||||||
Ner(DebertaV2NERModel),
|
Ner(Box<DebertaV2NERModel>),
|
||||||
TextClassification(DebertaV2SeqClassificationModel),
|
TextClassification(Box<DebertaV2SeqClassificationModel>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone, ValueEnum)]
|
#[derive(Parser, Debug, Clone, ValueEnum)]
|
||||||
@ -169,21 +169,16 @@ impl Args {
|
|||||||
|
|
||||||
match self.task {
|
match self.task {
|
||||||
ArgsTask::Ner => Ok((
|
ArgsTask::Ner => Ok((
|
||||||
TaskType::Ner(DebertaV2NERModel::load(
|
TaskType::Ner(DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))?.into()),
|
||||||
vb,
|
|
||||||
&config,
|
|
||||||
Some(id2label.clone()),
|
|
||||||
)?),
|
|
||||||
config,
|
config,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
id2label,
|
id2label,
|
||||||
)),
|
)),
|
||||||
ArgsTask::TextClassification => Ok((
|
ArgsTask::TextClassification => Ok((
|
||||||
TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
|
TaskType::TextClassification(
|
||||||
vb,
|
DebertaV2SeqClassificationModel::load(vb, &config, Some(id2label.clone()))?
|
||||||
&config,
|
.into(),
|
||||||
Some(id2label.clone()),
|
),
|
||||||
)?),
|
|
||||||
config,
|
config,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
id2label,
|
id2label,
|
||||||
|
@ -16,8 +16,8 @@ use std::path::PathBuf;
|
|||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
enum ModelType {
|
enum ModelType {
|
||||||
Masked(DistilBertForMaskedLM),
|
Masked(Box<DistilBertForMaskedLM>),
|
||||||
UnMasked(DistilBertModel),
|
UnMasked(Box<DistilBertModel>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelType {
|
impl ModelType {
|
||||||
@ -144,10 +144,12 @@ impl Args {
|
|||||||
|
|
||||||
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
||||||
match self.model {
|
match self.model {
|
||||||
Which::DistilbertForMaskedLM => {
|
Which::DistilbertForMaskedLM => Ok(ModelType::Masked(
|
||||||
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
|
DistilBertForMaskedLM::load(vb, config)?.into(),
|
||||||
}
|
)),
|
||||||
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
|
Which::DistilBert => Ok(ModelType::UnMasked(
|
||||||
|
DistilBertModel::load(vb, config)?.into(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -869,8 +869,8 @@ impl Moe {
|
|||||||
}
|
}
|
||||||
|
|
||||||
enum MoeOrMlp {
|
enum MoeOrMlp {
|
||||||
Moe(Moe),
|
Moe(Box<Moe>),
|
||||||
Mlp(Mlp),
|
Mlp(Box<Mlp>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MoeOrMlp {
|
impl MoeOrMlp {
|
||||||
@ -908,14 +908,17 @@ impl DecoderLayer {
|
|||||||
&& layer_idx >= cfg.first_k_dense_replace
|
&& layer_idx >= cfg.first_k_dense_replace
|
||||||
&& layer_idx % cfg.moe_layer_freq == 0
|
&& layer_idx % cfg.moe_layer_freq == 0
|
||||||
{
|
{
|
||||||
MoeOrMlp::Moe(Moe::new(
|
MoeOrMlp::Moe(
|
||||||
cfg,
|
Moe::new(
|
||||||
vb.pp("mlp"),
|
cfg,
|
||||||
cfg.n_shared_experts,
|
vb.pp("mlp"),
|
||||||
cfg.n_routed_experts.unwrap(),
|
cfg.n_shared_experts,
|
||||||
)?)
|
cfg.n_routed_experts.unwrap(),
|
||||||
|
)?
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?)
|
MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?.into())
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
@ -17,8 +17,8 @@ const CROP_NMS_THRESH: f32 = 0.7;
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum ImageEncoder {
|
enum ImageEncoder {
|
||||||
Original(ImageEncoderViT),
|
Original(Box<ImageEncoderViT>),
|
||||||
TinyViT(TinyViT),
|
TinyViT(Box<TinyViT>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for ImageEncoder {
|
impl Module for ImageEncoder {
|
||||||
@ -83,7 +83,7 @@ impl Sam {
|
|||||||
let pixel_std =
|
let pixel_std =
|
||||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
image_encoder: ImageEncoder::Original(image_encoder),
|
image_encoder: ImageEncoder::Original(image_encoder.into()),
|
||||||
prompt_encoder,
|
prompt_encoder,
|
||||||
mask_decoder,
|
mask_decoder,
|
||||||
pixel_std,
|
pixel_std,
|
||||||
@ -114,7 +114,7 @@ impl Sam {
|
|||||||
let pixel_std =
|
let pixel_std =
|
||||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
image_encoder: ImageEncoder::TinyViT(image_encoder),
|
image_encoder: ImageEncoder::TinyViT(image_encoder.into()),
|
||||||
prompt_encoder,
|
prompt_encoder,
|
||||||
mask_decoder,
|
mask_decoder,
|
||||||
pixel_std,
|
pixel_std,
|
||||||
|
@ -134,12 +134,7 @@ impl Scheduler for DDIMScheduler {
|
|||||||
timestep
|
timestep
|
||||||
};
|
};
|
||||||
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
|
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
|
||||||
let prev_timestep = if timestep > self.step_ratio {
|
let prev_timestep = timestep.saturating_sub(self.step_ratio);
|
||||||
timestep - self.step_ratio
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
let alpha_prod_t = self.alphas_cumprod[timestep];
|
let alpha_prod_t = self.alphas_cumprod[timestep];
|
||||||
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
|
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
|
||||||
let beta_prod_t = 1. - alpha_prod_t;
|
let beta_prod_t = 1. - alpha_prod_t;
|
||||||
|
Reference in New Issue
Block a user