diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 968ecd0d..9c9dc206 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -1,7 +1,7 @@
 #![allow(dead_code)]
-use anyhow::{Error as E, Result};
+use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
-use candle_hub::{api::Api, Repo, RepoType};
+use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -617,6 +617,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Run offline (you must have the files already cached)
+    #[arg(long)]
+    offline: bool,
+
     #[arg(long)]
     model_id: Option<String>,
 
@@ -627,6 +631,8 @@ struct Args {
 #[tokio::main]
 async fn main() -> Result<()> {
     use tokenizers::Tokenizer;
+    let start = std::time::Instant::now();
+    println!("Building {:?}", start.elapsed());
     let args = Args::parse();
 
     let device = if args.cpu {
@@ -644,21 +650,41 @@ async fn main() -> Result<()> {
         (None, None) => (default_model, default_revision),
     };
 
-    let api = Api::new()?;
     let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-    println!("building the model");
-    let config_filename = api.get(&repo, "config.json").await?;
+    let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
+        let cache = Cache::default();
+        (
+            cache
+                .get(&repo, "config.json")
+                .ok_or(anyhow!("Missing config file in cache"))?,
+            cache
+                .get(&repo, "tokenizer.json")
+                .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+            cache
+                .get(&repo, "model.safetensors")
+                .ok_or(anyhow!("Missing weights file in cache"))?,
+        )
+    } else {
+        let api = Api::new()?;
+        (
+            api.get(&repo, "config.json").await?,
+            api.get(&repo, "tokenizer.json").await?,
+            api.get(&repo, "model.safetensors").await?,
+        )
+    };
+    println!("Building {:?}", start.elapsed());
     let config = std::fs::read_to_string(config_filename)?;
     let config: Config = serde_json::from_str(&config)?;
-    let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+    println!("Config loaded {:?}", start.elapsed());
     let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+    println!("Tokenizer loaded {:?}", start.elapsed());
 
-    let weights = api.get(&repo, "model.safetensors").await?;
-    let weights = unsafe { candle::safetensors::MmapedFile::new(weights)? };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
     let model = BertModel::load(&vb, &config)?;
+    println!("Loaded {:?}", start.elapsed());
 
     let tokens = tokenizer
         .encode("This is an example sentence", true)
@@ -666,9 +692,13 @@ async fn main() -> Result<()> {
         .get_ids()
         .to_vec();
     let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
-    println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
-    let ys = model.forward(&token_ids, &token_type_ids)?;
-    println!("{ys}");
+    println!("Loaded and encoded {:?}", start.elapsed());
+    for _ in 0..100 {
+        let start = std::time::Instant::now();
+        let _ys = model.forward(&token_ids, &token_type_ids)?;
+        println!("Took {:?}", start.elapsed());
+        // println!("Ys {:?}", ys.shape());
+    }
     Ok(())
 }
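
Usage sketch: the diff adds an --offline flag to the bert example, so a plausible invocation, assuming the usual cargo example layout of this repository and that a previous online run has already populated the candle_hub cache, would be:

    cargo run --example bert --release -- --offline

Whether a package selector such as -p candle-examples is also needed depends on the workspace setup and is an assumption here rather than something the diff states.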