diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 9d0eccdf..70592013 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -86,9 +86,8 @@ impl Args { let config: Config = serde_json::from_str(&config)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; - let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }; let model = BertModel::load(vb, &config)?; Ok((model, tokenizer)) } diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs index 5f17109e..bf8dd24c 100644 --- a/candle-examples/examples/bigcode/main.rs +++ b/candle-examples/examples/bigcode/main.rs @@ -138,18 +138,9 @@ fn main() -> Result<()> { println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; - let weights = filenames - .iter() - .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? })) - .collect::>>()?; - let weights = weights - .iter() - .map(|f| Ok(f.deserialize()?)) - .collect::>>()?; - let start = std::time::Instant::now(); let device = candle_examples::device(args.cpu)?; - let vb = VarBuilder::from_safetensors(weights, DType::F32, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let config = Config::starcoder_1b(); let model = GPTBigCode::load(vb, config)?; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs index d3adb37c..6b3edeb4 100644 --- a/candle-examples/examples/dinov2/main.rs +++ b/candle-examples/examples/dinov2/main.rs @@ -42,9 +42,7 @@ pub fn main() -> anyhow::Result<()> { } Some(model) => model.into(), }; - let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; let model = dinov2::vit_small(vb)?; println!("model built"); let logits = model.forward(&image.unsqueeze(0)?)?; diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs index 1e45e301..0e4a2864 100644 --- a/candle-examples/examples/efficientnet/main.rs +++ b/candle-examples/examples/efficientnet/main.rs @@ -68,9 +68,7 @@ pub fn main() -> anyhow::Result<()> { } Some(model) => model.into(), }; - let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; let cfg = match args.which { Which::B0 => MBConvConfig::b0(), Which::B1 => MBConvConfig::b1(), diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index b0973d64..1cef25a8 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -177,21 +177,12 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); - let weights = filenames - .iter() - .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? })) - .collect::>>()?; - let weights = weights - .iter() - .map(|f| Ok(f.deserialize()?)) - .collect::>>()?; - let dtype = if args.use_f32 { DType::F32 } else { DType::BF16 }; - let vb = VarBuilder::from_safetensors(weights, dtype, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let config = Config::falcon7b(); config.validate()?; let model = Falcon::load(vb, config)?; diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index b2d7d938..4bf91d92 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -172,17 +172,9 @@ fn main() -> Result<()> { } println!("building the model"); - let handles = filenames - .iter() - .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? })) - .collect::>>()?; - let tensors: Vec<_> = handles - .iter() - .map(|h| Ok(h.deserialize()?)) - .collect::>>()?; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; - let vb = VarBuilder::from_safetensors(tensors, dtype, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache) } }; diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs index 0fae67b5..a39cfec2 100644 --- a/candle-examples/examples/musicgen/main.rs +++ b/candle-examples/examples/musicgen/main.rs @@ -73,9 +73,7 @@ fn main() -> Result<()> { )) .get("model.safetensors")?, }; - let model = unsafe { candle::safetensors::MmapedFile::new(model)? }; - let model = model.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? }; let config = GenConfig::small(); let mut model = MusicgenForConditionalGeneration::load(vb, config)?; diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 25c7db98..3b1e7dc1 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -149,18 +149,9 @@ fn main() -> Result<()> { println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; - let weights = filenames - .iter() - .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? })) - .collect::>>()?; - let weights = weights - .iter() - .map(|f| Ok(f.deserialize()?)) - .collect::>>()?; - let start = std::time::Instant::now(); let device = candle_examples::device(args.cpu)?; - let vb = VarBuilder::from_safetensors(weights, DType::F32, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let config = Config::v1_5(); let model = Model::new(&config, vb)?; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 3d9898b6..71abe116 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -82,9 +82,7 @@ pub fn main() -> anyhow::Result<()> { api.get(filename)? } }; - let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }; let sam = if args.use_tiny { sam::Sam::new_tiny(vb)? // tiny vit_t } else { diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index c71d562a..0aa4db41 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -481,9 +481,8 @@ fn main() -> Result<()> { let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?; println!("loaded mel: {:?}", mel.dims()); - let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device); + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? }; let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let mut model = Whisper::load(&vb, config)?; diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs index 95f3b8f4..40b43c1d 100644 --- a/candle-examples/examples/wuerstchen/main.rs +++ b/candle-examples/examples/wuerstchen/main.rs @@ -287,10 +287,10 @@ fn run(args: Args) -> Result<()> { )?; let prior = { - let prior_weights = ModelFile::Prior.get(prior_weights)?; - let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? }; - let weights = weights.deserialize()?; - let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let file = ModelFile::Prior.get(prior_weights)?; + let vb = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)? + }; wuerstchen::prior::WPrior::new( /* c_in */ PRIOR_CIN, /* c */ 1536, @@ -324,10 +324,10 @@ fn run(args: Args) -> Result<()> { println!("Building the vqgan."); let vqgan = { - let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?; - let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? }; - let weights = weights.deserialize()?; - let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let file = ModelFile::VqGan.get(vqgan_weights)?; + let vb = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)? + }; wuerstchen::paella_vq::PaellaVQ::new(vb)? }; @@ -335,10 +335,10 @@ fn run(args: Args) -> Result<()> { // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json let decoder = { - let decoder_weights = ModelFile::Decoder.get(decoder_weights)?; - let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? }; - let weights = weights.deserialize()?; - let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let file = ModelFile::Decoder.get(decoder_weights)?; + let vb = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)? + }; wuerstchen::diffnext::WDiffNeXt::new( /* c_in */ DECODER_CIN, /* c_out */ DECODER_CIN, diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs index ecf75bdf..5b1937ac 100644 --- a/candle-examples/examples/yolo-v3/main.rs +++ b/candle-examples/examples/yolo-v3/main.rs @@ -146,9 +146,7 @@ pub fn main() -> Result<()> { // Create the model and load the weights from the file. let model = args.model()?; - let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &Device::Cpu)? }; let config = args.config()?; let darknet = darknet::parse_config(config)?; let model = darknet.build_model(vb)?; diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index dc709db4..af8cf98a 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -381,9 +381,7 @@ pub fn run(args: Args) -> anyhow::Result<()> { Which::X => Multiples::x(), }; let model = args.model()?; - let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; - let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }; let model = T::load(vb, multiples)?; println!("model loaded"); for image_name in args.images.iter() { diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index c6f1b904..7fdedaae 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -255,9 +255,8 @@ impl StableDiffusionConfig { device: &Device, dtype: DType, ) -> Result { - let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? }; - let weights = weights.deserialize()?; - let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device); + let vs_ae = + unsafe { nn::VarBuilder::from_mmaped_safetensors(&[vae_weights], dtype, device)? }; // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?; Ok(autoencoder) @@ -271,9 +270,8 @@ impl StableDiffusionConfig { use_flash_attn: bool, dtype: DType, ) -> Result { - let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? }; - let weights = weights.deserialize()?; - let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device); + let vs_unet = + unsafe { nn::VarBuilder::from_mmaped_safetensors(&[unet_weights], dtype, device)? }; let unet = unet_2d::UNet2DConditionModel::new( vs_unet, in_channels, @@ -295,9 +293,7 @@ pub fn build_clip_transformer>( device: &Device, dtype: DType, ) -> Result { - let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; - let weights = weights.deserialize()?; - let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device); + let vs = unsafe { nn::VarBuilder::from_mmaped_safetensors(&[clip_weights], dtype, device)? }; let text_model = clip::ClipTextTransformer::new(vs, clip)?; Ok(text_model) }