diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index eb79e17e..b8eaf99c 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -16,10 +16,9 @@ fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
 fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
     let magic_number = read_u32(reader)?;
     if magic_number != expected {
-        Err(io::Error::new(
-            io::ErrorKind::Other,
-            format!("incorrect magic number {magic_number} != {expected}"),
-        ))?;
+        Err(io::Error::other(format!(
+            "incorrect magic number {magic_number} != {expected}"
+        )))?;
     }
     Ok(())
 }
diff --git a/candle-examples/examples/debertav2/main.rs b/candle-examples/examples/debertav2/main.rs
index b1938038..2f5f3ff2 100644
--- a/candle-examples/examples/debertav2/main.rs
+++ b/candle-examples/examples/debertav2/main.rs
@@ -20,8 +20,8 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{Encoding, PaddingParams, Tokenizer};
 
 enum TaskType {
-    Ner(DebertaV2NERModel),
-    TextClassification(DebertaV2SeqClassificationModel),
+    Ner(Box<DebertaV2NERModel>),
+    TextClassification(Box<DebertaV2SeqClassificationModel>),
 }
 
 #[derive(Parser, Debug, Clone, ValueEnum)]
@@ -169,21 +169,16 @@ impl Args {
 
         match self.task {
             ArgsTask::Ner => Ok((
-                TaskType::Ner(DebertaV2NERModel::load(
-                    vb,
-                    &config,
-                    Some(id2label.clone()),
-                )?),
+                TaskType::Ner(DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))?.into()),
                 config,
                 tokenizer,
                 id2label,
             )),
             ArgsTask::TextClassification => Ok((
-                TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
-                    vb,
-                    &config,
-                    Some(id2label.clone()),
-                )?),
+                TaskType::TextClassification(
+                    DebertaV2SeqClassificationModel::load(vb, &config, Some(id2label.clone()))?
+                        .into(),
+                ),
                 config,
                 tokenizer,
                 id2label,
diff --git a/candle-examples/examples/distilbert/main.rs b/candle-examples/examples/distilbert/main.rs
index c9c178d6..7f9df7cf 100644
--- a/candle-examples/examples/distilbert/main.rs
+++ b/candle-examples/examples/distilbert/main.rs
@@ -16,8 +16,8 @@ use std::path::PathBuf;
 use tokenizers::Tokenizer;
 
 enum ModelType {
-    Masked(DistilBertForMaskedLM),
-    UnMasked(DistilBertModel),
+    Masked(Box<DistilBertForMaskedLM>),
+    UnMasked(Box<DistilBertModel>),
 }
 
 impl ModelType {
@@ -144,10 +144,12 @@ impl Args {
 
     fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
         match self.model {
-            Which::DistilbertForMaskedLM => {
-                Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
-            }
-            Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
+            Which::DistilbertForMaskedLM => Ok(ModelType::Masked(
+                DistilBertForMaskedLM::load(vb, config)?.into(),
+            )),
+            Which::DistilBert => Ok(ModelType::UnMasked(
+                DistilBertModel::load(vb, config)?.into(),
+            )),
         }
     }
 }
diff --git a/candle-transformers/src/models/deepseek2.rs b/candle-transformers/src/models/deepseek2.rs
index 16c6907a..6a418b43 100644
--- a/candle-transformers/src/models/deepseek2.rs
+++ b/candle-transformers/src/models/deepseek2.rs
@@ -869,8 +869,8 @@ impl Moe {
 }
 
 enum MoeOrMlp {
-    Moe(Moe),
-    Mlp(Mlp),
+    Moe(Box<Moe>),
+    Mlp(Box<Mlp>),
 }
 
 impl MoeOrMlp {
@@ -908,14 +908,17 @@ impl DecoderLayer {
             && layer_idx >= cfg.first_k_dense_replace
             && layer_idx % cfg.moe_layer_freq == 0
         {
-            MoeOrMlp::Moe(Moe::new(
-                cfg,
-                vb.pp("mlp"),
-                cfg.n_shared_experts,
-                cfg.n_routed_experts.unwrap(),
-            )?)
+            MoeOrMlp::Moe(
+                Moe::new(
+                    cfg,
+                    vb.pp("mlp"),
+                    cfg.n_shared_experts,
+                    cfg.n_routed_experts.unwrap(),
+                )?
+                .into(),
+            )
         } else {
-            MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?)
+            MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?.into())
         };
 
         Ok(Self {
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
index a2156a75..7a84eef4 100644
--- a/candle-transformers/src/models/segment_anything/sam.rs
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -17,8 +17,8 @@ const CROP_NMS_THRESH: f32 = 0.7;
 
 #[derive(Debug)]
 enum ImageEncoder {
-    Original(ImageEncoderViT),
-    TinyViT(TinyViT),
+    Original(Box<ImageEncoderViT>),
+    TinyViT(Box<TinyViT>),
 }
 
 impl Module for ImageEncoder {
@@ -83,7 +83,7 @@ impl Sam {
         let pixel_std =
             Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
         Ok(Self {
-            image_encoder: ImageEncoder::Original(image_encoder),
+            image_encoder: ImageEncoder::Original(image_encoder.into()),
             prompt_encoder,
             mask_decoder,
             pixel_std,
@@ -114,7 +114,7 @@ impl Sam {
         let pixel_std =
             Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
         Ok(Self {
-            image_encoder: ImageEncoder::TinyViT(image_encoder),
+            image_encoder: ImageEncoder::TinyViT(image_encoder.into()),
             prompt_encoder,
             mask_decoder,
             pixel_std,
@@ -134,12 +134,7 @@ impl Scheduler for DDIMScheduler {
diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
index ae2b40db..d8ef5ec9 100644
--- a/candle-transformers/src/models/stable_diffusion/ddim.rs
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -134,12 +134,7 @@ impl Scheduler for DDIMScheduler {
             timestep
         };
         // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
-        let prev_timestep = if timestep > self.step_ratio {
-            timestep - self.step_ratio
-        } else {
-            0
-        };
-
+        let prev_timestep = timestep.saturating_sub(self.step_ratio);
         let alpha_prod_t = self.alphas_cumprod[timestep];
         let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
         let beta_prod_t = 1. - alpha_prod_t;